feat: Waku v2 bridge

Issue #12610
This commit is contained in:
Michal Iskierko
2023-11-12 13:29:38 +01:00
parent 56e7bd01ca
commit 6d31343205
6716 changed files with 1982502 additions and 5891 deletions

View File

@@ -0,0 +1,232 @@
package backoff
import (
"math"
"math/rand"
"sync"
"time"
logging "github.com/ipfs/go-log/v2"
)
var log = logging.Logger("discovery-backoff")
type BackoffFactory func() BackoffStrategy
// BackoffStrategy describes how backoff will be implemented. BackoffStrategies are stateful.
type BackoffStrategy interface {
// Delay calculates how long the next backoff duration should be, given the prior calls to Delay
Delay() time.Duration
// Reset clears the internal state of the BackoffStrategy
Reset()
}
// Jitter implementations taken roughly from https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
// Jitter must return a duration between min and max. Min must be lower than, or equal to, max.
type Jitter func(duration, min, max time.Duration, rng *rand.Rand) time.Duration
// FullJitter returns a random number, uniformly chosen from the range [min, boundedDur].
// boundedDur is the duration bounded between min and max.
func FullJitter(duration, min, max time.Duration, rng *rand.Rand) time.Duration {
if duration <= min {
return min
}
normalizedDur := boundedDuration(duration, min, max) - min
return boundedDuration(time.Duration(rng.Int63n(int64(normalizedDur)))+min, min, max)
}
// NoJitter returns the duration bounded between min and max
func NoJitter(duration, min, max time.Duration, rng *rand.Rand) time.Duration {
return boundedDuration(duration, min, max)
}
type randomizedBackoff struct {
min time.Duration
max time.Duration
rng *rand.Rand
}
func (b *randomizedBackoff) BoundedDelay(duration time.Duration) time.Duration {
return boundedDuration(duration, b.min, b.max)
}
func boundedDuration(d, min, max time.Duration) time.Duration {
if d < min {
return min
}
if d > max {
return max
}
return d
}
type attemptBackoff struct {
attempt int
jitter Jitter
randomizedBackoff
}
func (b *attemptBackoff) Reset() {
b.attempt = 0
}
// NewFixedBackoff creates a BackoffFactory with a constant backoff duration
func NewFixedBackoff(delay time.Duration) BackoffFactory {
return func() BackoffStrategy {
return &fixedBackoff{delay: delay}
}
}
type fixedBackoff struct {
delay time.Duration
}
func (b *fixedBackoff) Delay() time.Duration {
return b.delay
}
func (b *fixedBackoff) Reset() {}
// NewPolynomialBackoff creates a BackoffFactory with backoff of the form c0*x^0, c1*x^1, ...cn*x^n where x is the attempt number
// jitter is the function for adding randomness around the backoff
// timeUnits are the units of time the polynomial is evaluated in
// polyCoefs is the array of polynomial coefficients from [c0, c1, ... cn]
func NewPolynomialBackoff(min, max time.Duration, jitter Jitter,
timeUnits time.Duration, polyCoefs []float64, rngSrc rand.Source) BackoffFactory {
rng := rand.New(&lockedSource{src: rngSrc})
return func() BackoffStrategy {
return &polynomialBackoff{
attemptBackoff: attemptBackoff{
randomizedBackoff: randomizedBackoff{
min: min,
max: max,
rng: rng,
},
jitter: jitter,
},
timeUnits: timeUnits,
poly: polyCoefs,
}
}
}
type polynomialBackoff struct {
attemptBackoff
timeUnits time.Duration
poly []float64
}
func (b *polynomialBackoff) Delay() time.Duration {
var polySum float64
switch len(b.poly) {
case 0:
return 0
case 1:
polySum = b.poly[0]
default:
polySum = b.poly[0]
exp := 1
attempt := b.attempt
b.attempt++
for _, c := range b.poly[1:] {
exp *= attempt
polySum += float64(exp) * c
}
}
return b.jitter(time.Duration(float64(b.timeUnits)*polySum), b.min, b.max, b.rng)
}
// NewExponentialBackoff creates a BackoffFactory with backoff of the form base^x + offset where x is the attempt number
// jitter is the function for adding randomness around the backoff
// timeUnits are the units of time the base^x is evaluated in
func NewExponentialBackoff(min, max time.Duration, jitter Jitter,
timeUnits time.Duration, base float64, offset time.Duration, rngSrc rand.Source) BackoffFactory {
rng := rand.New(&lockedSource{src: rngSrc})
return func() BackoffStrategy {
return &exponentialBackoff{
attemptBackoff: attemptBackoff{
randomizedBackoff: randomizedBackoff{
min: min,
max: max,
rng: rng,
},
jitter: jitter,
},
timeUnits: timeUnits,
base: base,
offset: offset,
}
}
}
type exponentialBackoff struct {
attemptBackoff
timeUnits time.Duration
base float64
offset time.Duration
}
func (b *exponentialBackoff) Delay() time.Duration {
attempt := b.attempt
b.attempt++
return b.jitter(
time.Duration(math.Pow(b.base, float64(attempt))*float64(b.timeUnits))+b.offset, b.min, b.max, b.rng)
}
// NewExponentialDecorrelatedJitter creates a BackoffFactory with backoff of the roughly of the form base^x where x is the attempt number.
// Delays start at the minimum duration and after each attempt delay = rand(min, delay * base), bounded by the max
// See https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ for more information
func NewExponentialDecorrelatedJitter(min, max time.Duration, base float64, rngSrc rand.Source) BackoffFactory {
rng := rand.New(&lockedSource{src: rngSrc})
return func() BackoffStrategy {
return &exponentialDecorrelatedJitter{
randomizedBackoff: randomizedBackoff{
min: min,
max: max,
rng: rng,
},
base: base,
}
}
}
type exponentialDecorrelatedJitter struct {
randomizedBackoff
base float64
lastDelay time.Duration
}
func (b *exponentialDecorrelatedJitter) Delay() time.Duration {
if b.lastDelay < b.min {
b.lastDelay = b.min
return b.lastDelay
}
nextMax := int64(float64(b.lastDelay) * b.base)
b.lastDelay = boundedDuration(time.Duration(b.rng.Int63n(nextMax-int64(b.min)))+b.min, b.min, b.max)
return b.lastDelay
}
func (b *exponentialDecorrelatedJitter) Reset() { b.lastDelay = 0 }
type lockedSource struct {
lk sync.Mutex
src rand.Source
}
func (r *lockedSource) Int63() (n int64) {
r.lk.Lock()
n = r.src.Int63()
r.lk.Unlock()
return
}
func (r *lockedSource) Seed(seed int64) {
r.lk.Lock()
r.src.Seed(seed)
r.lk.Unlock()
}

View File

@@ -0,0 +1,337 @@
package backoff
import (
"context"
"fmt"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/peer"
ma "github.com/multiformats/go-multiaddr"
)
// BackoffDiscovery is an implementation of discovery that caches peer data and attenuates repeated queries
type BackoffDiscovery struct {
disc discovery.Discovery
stratFactory BackoffFactory
peerCache map[string]*backoffCache
peerCacheMux sync.RWMutex
parallelBufSz int
returnedBufSz int
clock clock
}
type BackoffDiscoveryOption func(*BackoffDiscovery) error
func NewBackoffDiscovery(disc discovery.Discovery, stratFactory BackoffFactory, opts ...BackoffDiscoveryOption) (discovery.Discovery, error) {
b := &BackoffDiscovery{
disc: disc,
stratFactory: stratFactory,
peerCache: make(map[string]*backoffCache),
parallelBufSz: 32,
returnedBufSz: 32,
clock: realClock{},
}
for _, opt := range opts {
if err := opt(b); err != nil {
return nil, err
}
}
return b, nil
}
// WithBackoffDiscoverySimultaneousQueryBufferSize sets the buffer size for the channels between the main FindPeers query
// for a given namespace and all simultaneous FindPeers queries for the namespace
func WithBackoffDiscoverySimultaneousQueryBufferSize(size int) BackoffDiscoveryOption {
return func(b *BackoffDiscovery) error {
if size < 0 {
return fmt.Errorf("cannot set size to be smaller than 0")
}
b.parallelBufSz = size
return nil
}
}
// WithBackoffDiscoveryReturnedChannelSize sets the size of the buffer to be used during a FindPeer query.
// Note: This does not apply if the query occurs during the backoff time
func WithBackoffDiscoveryReturnedChannelSize(size int) BackoffDiscoveryOption {
return func(b *BackoffDiscovery) error {
if size < 0 {
return fmt.Errorf("cannot set size to be smaller than 0")
}
b.returnedBufSz = size
return nil
}
}
type clock interface {
Now() time.Time
}
type realClock struct{}
func (c realClock) Now() time.Time {
return time.Now()
}
// withClock lets you override the default time.Now() call. Useful for tests.
func withClock(c clock) BackoffDiscoveryOption {
return func(b *BackoffDiscovery) error {
b.clock = c
return nil
}
}
type backoffCache struct {
// strat is assigned on creation and not written to
strat BackoffStrategy
mux sync.Mutex // guards writes to all following fields
nextDiscover time.Time
prevPeers map[peer.ID]peer.AddrInfo
peers map[peer.ID]peer.AddrInfo
sendingChs map[chan peer.AddrInfo]int
ongoing bool
clock clock
}
func (d *BackoffDiscovery) Advertise(ctx context.Context, ns string, opts ...discovery.Option) (time.Duration, error) {
return d.disc.Advertise(ctx, ns, opts...)
}
func (d *BackoffDiscovery) FindPeers(ctx context.Context, ns string, opts ...discovery.Option) (<-chan peer.AddrInfo, error) {
// Get options
var options discovery.Options
err := options.Apply(opts...)
if err != nil {
return nil, err
}
// Get cached peers
d.peerCacheMux.RLock()
c, ok := d.peerCache[ns]
d.peerCacheMux.RUnlock()
/*
Overall plan:
If it's time to look for peers, look for peers, then return them
If it's not time then return cache
If it's time to look for peers, but we have already started looking. Get up to speed with ongoing request
*/
// Setup cache if we don't have one yet
if !ok {
pc := &backoffCache{
nextDiscover: time.Time{},
prevPeers: make(map[peer.ID]peer.AddrInfo),
peers: make(map[peer.ID]peer.AddrInfo),
sendingChs: make(map[chan peer.AddrInfo]int),
strat: d.stratFactory(),
clock: d.clock,
}
d.peerCacheMux.Lock()
c, ok = d.peerCache[ns]
if !ok {
d.peerCache[ns] = pc
c = pc
}
d.peerCacheMux.Unlock()
}
c.mux.Lock()
defer c.mux.Unlock()
timeExpired := d.clock.Now().After(c.nextDiscover)
// If it's not yet time to search again and no searches are in progress then return cached peers
if !(timeExpired || c.ongoing) {
chLen := options.Limit
if chLen == 0 {
chLen = len(c.prevPeers)
} else if chLen > len(c.prevPeers) {
chLen = len(c.prevPeers)
}
pch := make(chan peer.AddrInfo, chLen)
for _, ai := range c.prevPeers {
select {
case pch <- ai:
default:
// skip if we have asked for a lower limit than the number of peers known
}
}
close(pch)
return pch, nil
}
// If a request is not already in progress setup a dispatcher channel for dispatching incoming peers
if !c.ongoing {
pch, err := d.disc.FindPeers(ctx, ns, opts...)
if err != nil {
return nil, err
}
c.ongoing = true
go findPeerDispatcher(ctx, c, pch)
}
// Setup receiver channel for receiving peers from ongoing requests
evtCh := make(chan peer.AddrInfo, d.parallelBufSz)
pch := make(chan peer.AddrInfo, d.returnedBufSz)
rcvPeers := make([]peer.AddrInfo, 0, 32)
for _, ai := range c.peers {
rcvPeers = append(rcvPeers, ai)
}
c.sendingChs[evtCh] = options.Limit
go findPeerReceiver(ctx, pch, evtCh, rcvPeers)
return pch, nil
}
func findPeerDispatcher(ctx context.Context, c *backoffCache, pch <-chan peer.AddrInfo) {
defer func() {
c.mux.Lock()
// If the peer addresses have changed reset the backoff
if checkUpdates(c.prevPeers, c.peers) {
c.strat.Reset()
c.prevPeers = c.peers
}
c.nextDiscover = c.clock.Now().Add(c.strat.Delay())
c.ongoing = false
c.peers = make(map[peer.ID]peer.AddrInfo)
for ch := range c.sendingChs {
close(ch)
}
c.sendingChs = make(map[chan peer.AddrInfo]int)
c.mux.Unlock()
}()
for {
select {
case ai, ok := <-pch:
if !ok {
return
}
c.mux.Lock()
// If we receive the same peer multiple times return the address union
var sendAi peer.AddrInfo
if prevAi, ok := c.peers[ai.ID]; ok {
if combinedAi := mergeAddrInfos(prevAi, ai); combinedAi != nil {
sendAi = *combinedAi
} else {
c.mux.Unlock()
continue
}
} else {
sendAi = ai
}
c.peers[ai.ID] = sendAi
for ch, rem := range c.sendingChs {
if rem > 0 {
ch <- sendAi
c.sendingChs[ch] = rem - 1
}
}
c.mux.Unlock()
case <-ctx.Done():
return
}
}
}
func findPeerReceiver(ctx context.Context, pch, evtCh chan peer.AddrInfo, rcvPeers []peer.AddrInfo) {
defer close(pch)
for {
select {
case ai, ok := <-evtCh:
if ok {
rcvPeers = append(rcvPeers, ai)
sentAll := true
sendPeers:
for i, p := range rcvPeers {
select {
case pch <- p:
default:
rcvPeers = rcvPeers[i:]
sentAll = false
break sendPeers
}
}
if sentAll {
rcvPeers = []peer.AddrInfo{}
}
} else {
for _, p := range rcvPeers {
select {
case pch <- p:
case <-ctx.Done():
return
}
}
return
}
case <-ctx.Done():
return
}
}
}
func mergeAddrInfos(prevAi, newAi peer.AddrInfo) *peer.AddrInfo {
seen := make(map[string]struct{}, len(prevAi.Addrs))
combinedAddrs := make([]ma.Multiaddr, 0, len(prevAi.Addrs))
addAddrs := func(addrs []ma.Multiaddr) {
for _, addr := range addrs {
if _, ok := seen[addr.String()]; ok {
continue
}
seen[addr.String()] = struct{}{}
combinedAddrs = append(combinedAddrs, addr)
}
}
addAddrs(prevAi.Addrs)
addAddrs(newAi.Addrs)
if len(combinedAddrs) > len(prevAi.Addrs) {
combinedAi := &peer.AddrInfo{ID: prevAi.ID, Addrs: combinedAddrs}
return combinedAi
}
return nil
}
func checkUpdates(orig, update map[peer.ID]peer.AddrInfo) bool {
if len(orig) != len(update) {
return true
}
for p, ai := range update {
if prevAi, ok := orig[p]; ok {
if combinedAi := mergeAddrInfos(prevAi, ai); combinedAi != nil {
return true
}
} else {
return true
}
}
return false
}

View File

@@ -0,0 +1,94 @@
package backoff
import (
"context"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
lru "github.com/hashicorp/golang-lru/v2"
)
// BackoffConnector is a utility to connect to peers, but only if we have not recently tried connecting to them already
type BackoffConnector struct {
cache *lru.TwoQueueCache[peer.ID, *connCacheData]
host host.Host
connTryDur time.Duration
backoff BackoffFactory
mux sync.Mutex
}
// NewBackoffConnector creates a utility to connect to peers, but only if we have not recently tried connecting to them already
// cacheSize is the size of a TwoQueueCache
// connectionTryDuration is how long we attempt to connect to a peer before giving up
// backoff describes the strategy used to decide how long to backoff after previously attempting to connect to a peer
func NewBackoffConnector(h host.Host, cacheSize int, connectionTryDuration time.Duration, backoff BackoffFactory) (*BackoffConnector, error) {
cache, err := lru.New2Q[peer.ID, *connCacheData](cacheSize)
if err != nil {
return nil, err
}
return &BackoffConnector{
cache: cache,
host: h,
connTryDur: connectionTryDuration,
backoff: backoff,
}, nil
}
type connCacheData struct {
nextTry time.Time
strat BackoffStrategy
}
// Connect attempts to connect to the peers passed in by peerCh. Will not connect to peers if they are within the backoff period.
// As Connect will attempt to dial peers as soon as it learns about them, the caller should try to keep the number,
// and rate, of inbound peers manageable.
func (c *BackoffConnector) Connect(ctx context.Context, peerCh <-chan peer.AddrInfo) {
for {
select {
case pi, ok := <-peerCh:
if !ok {
return
}
if pi.ID == c.host.ID() || pi.ID == "" {
continue
}
c.mux.Lock()
var cachedPeer *connCacheData
if tv, ok := c.cache.Get(pi.ID); ok {
now := time.Now()
if now.Before(tv.nextTry) {
c.mux.Unlock()
continue
}
tv.nextTry = now.Add(tv.strat.Delay())
} else {
cachedPeer = &connCacheData{strat: c.backoff()}
cachedPeer.nextTry = time.Now().Add(cachedPeer.strat.Delay())
c.cache.Add(pi.ID, cachedPeer)
}
c.mux.Unlock()
go func(pi peer.AddrInfo) {
ctx, cancel := context.WithTimeout(ctx, c.connTryDur)
defer cancel()
err := c.host.Connect(ctx, pi)
if err != nil {
log.Debugf("Error connecting to pubsub peer %s: %s", pi.ID, err.Error())
return
}
}(pi)
case <-ctx.Done():
log.Infof("discovery: backoff connector context error %v", ctx.Err())
return
}
}
}

View File

@@ -0,0 +1,450 @@
package autonat
import (
"context"
"math/rand"
"sync/atomic"
"time"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
var log = logging.Logger("autonat")
const maxConfidence = 3
// AmbientAutoNAT is the implementation of ambient NAT autodiscovery
type AmbientAutoNAT struct {
host host.Host
*config
ctx context.Context
ctxCancel context.CancelFunc // is closed when Close is called
backgroundRunning chan struct{} // is closed when the background go routine exits
inboundConn chan network.Conn
dialResponses chan error
// status is an autoNATResult reflecting current status.
status atomic.Pointer[network.Reachability]
// Reflects the confidence on of the NATStatus being private, as a single
// dialback may fail for reasons unrelated to NAT.
// If it is <3, then multiple autoNAT peers may be contacted for dialback
// If only a single autoNAT peer is known, then the confidence increases
// for each failure until it reaches 3.
confidence int
lastInbound time.Time
lastProbeTry time.Time
lastProbe time.Time
recentProbes map[peer.ID]time.Time
service *autoNATService
emitReachabilityChanged event.Emitter
subscriber event.Subscription
}
// StaticAutoNAT is a simple AutoNAT implementation when a single NAT status is desired.
type StaticAutoNAT struct {
host host.Host
reachability network.Reachability
service *autoNATService
}
// New creates a new NAT autodiscovery system attached to a host
func New(h host.Host, options ...Option) (AutoNAT, error) {
var err error
conf := new(config)
conf.host = h
conf.dialPolicy.host = h
if err = defaults(conf); err != nil {
return nil, err
}
if conf.addressFunc == nil {
conf.addressFunc = h.Addrs
}
for _, o := range options {
if err = o(conf); err != nil {
return nil, err
}
}
emitReachabilityChanged, _ := h.EventBus().Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful)
var service *autoNATService
if (!conf.forceReachability || conf.reachability == network.ReachabilityPublic) && conf.dialer != nil {
service, err = newAutoNATService(conf)
if err != nil {
return nil, err
}
service.Enable()
}
if conf.forceReachability {
emitReachabilityChanged.Emit(event.EvtLocalReachabilityChanged{Reachability: conf.reachability})
return &StaticAutoNAT{
host: h,
reachability: conf.reachability,
service: service,
}, nil
}
ctx, cancel := context.WithCancel(context.Background())
as := &AmbientAutoNAT{
ctx: ctx,
ctxCancel: cancel,
backgroundRunning: make(chan struct{}),
host: h,
config: conf,
inboundConn: make(chan network.Conn, 5),
dialResponses: make(chan error, 1),
emitReachabilityChanged: emitReachabilityChanged,
service: service,
recentProbes: make(map[peer.ID]time.Time),
}
reachability := network.ReachabilityUnknown
as.status.Store(&reachability)
subscriber, err := as.host.EventBus().Subscribe(
[]any{new(event.EvtLocalAddressesUpdated), new(event.EvtPeerIdentificationCompleted)},
eventbus.Name("autonat"),
)
if err != nil {
return nil, err
}
as.subscriber = subscriber
h.Network().Notify(as)
go as.background()
return as, nil
}
// Status returns the AutoNAT observed reachability status.
func (as *AmbientAutoNAT) Status() network.Reachability {
s := as.status.Load()
return *s
}
func (as *AmbientAutoNAT) emitStatus() {
status := *as.status.Load()
as.emitReachabilityChanged.Emit(event.EvtLocalReachabilityChanged{Reachability: status})
if as.metricsTracer != nil {
as.metricsTracer.ReachabilityStatus(status)
}
}
func ipInList(candidate ma.Multiaddr, list []ma.Multiaddr) bool {
candidateIP, _ := manet.ToIP(candidate)
for _, i := range list {
if ip, err := manet.ToIP(i); err == nil && ip.Equal(candidateIP) {
return true
}
}
return false
}
func (as *AmbientAutoNAT) background() {
defer close(as.backgroundRunning)
// wait a bit for the node to come online and establish some connections
// before starting autodetection
delay := as.config.bootDelay
subChan := as.subscriber.Out()
defer as.subscriber.Close()
defer as.emitReachabilityChanged.Close()
timer := time.NewTimer(delay)
defer timer.Stop()
timerRunning := true
retryProbe := false
for {
select {
// new inbound connection.
case conn := <-as.inboundConn:
localAddrs := as.host.Addrs()
if manet.IsPublicAddr(conn.RemoteMultiaddr()) &&
!ipInList(conn.RemoteMultiaddr(), localAddrs) {
as.lastInbound = time.Now()
}
case e := <-subChan:
switch e := e.(type) {
case event.EvtLocalAddressesUpdated:
// On local address update, reduce confidence from maximum so that we schedule
// the next probe sooner
if as.confidence == maxConfidence {
as.confidence--
}
case event.EvtPeerIdentificationCompleted:
if s, err := as.host.Peerstore().SupportsProtocols(e.Peer, AutoNATProto); err == nil && len(s) > 0 {
currentStatus := *as.status.Load()
if currentStatus == network.ReachabilityUnknown {
as.tryProbe(e.Peer)
}
}
default:
log.Errorf("unknown event type: %T", e)
}
// probe finished.
case err, ok := <-as.dialResponses:
if !ok {
return
}
if IsDialRefused(err) {
retryProbe = true
} else {
as.handleDialResponse(err)
}
case <-timer.C:
peer := as.getPeerToProbe()
as.tryProbe(peer)
timerRunning = false
retryProbe = false
case <-as.ctx.Done():
return
}
// Drain the timer channel if it hasn't fired in preparation for Resetting it.
if timerRunning && !timer.Stop() {
<-timer.C
}
timer.Reset(as.scheduleProbe(retryProbe))
timerRunning = true
}
}
func (as *AmbientAutoNAT) cleanupRecentProbes() {
fixedNow := time.Now()
for k, v := range as.recentProbes {
if fixedNow.Sub(v) > as.throttlePeerPeriod {
delete(as.recentProbes, k)
}
}
}
// scheduleProbe calculates when the next probe should be scheduled for.
func (as *AmbientAutoNAT) scheduleProbe(retryProbe bool) time.Duration {
// Our baseline is a probe every 'AutoNATRefreshInterval'
// This is modulated by:
// * if we are in an unknown state, have low confidence, or we want to retry because a probe was refused that
// should drop to 'AutoNATRetryInterval'
// * recent inbound connections (implying continued connectivity) should decrease the retry when public
// * recent inbound connections when not public mean we should try more actively to see if we're public.
fixedNow := time.Now()
currentStatus := *as.status.Load()
nextProbe := fixedNow
// Don't look for peers in the peer store more than once per second.
if !as.lastProbeTry.IsZero() {
backoff := as.lastProbeTry.Add(time.Second)
if backoff.After(nextProbe) {
nextProbe = backoff
}
}
if !as.lastProbe.IsZero() {
untilNext := as.config.refreshInterval
if retryProbe {
untilNext = as.config.retryInterval
} else if currentStatus == network.ReachabilityUnknown {
untilNext = as.config.retryInterval
} else if as.confidence < maxConfidence {
untilNext = as.config.retryInterval
} else if currentStatus == network.ReachabilityPublic && as.lastInbound.After(as.lastProbe) {
untilNext *= 2
} else if currentStatus != network.ReachabilityPublic && as.lastInbound.After(as.lastProbe) {
untilNext /= 5
}
if as.lastProbe.Add(untilNext).After(nextProbe) {
nextProbe = as.lastProbe.Add(untilNext)
}
}
if as.metricsTracer != nil {
as.metricsTracer.NextProbeTime(nextProbe)
}
return nextProbe.Sub(fixedNow)
}
// handleDialResponse updates the current status based on dial response.
func (as *AmbientAutoNAT) handleDialResponse(dialErr error) {
var observation network.Reachability
switch {
case dialErr == nil:
observation = network.ReachabilityPublic
case IsDialError(dialErr):
observation = network.ReachabilityPrivate
default:
observation = network.ReachabilityUnknown
}
as.recordObservation(observation)
}
// recordObservation updates NAT status and confidence
func (as *AmbientAutoNAT) recordObservation(observation network.Reachability) {
currentStatus := *as.status.Load()
if observation == network.ReachabilityPublic {
changed := false
if currentStatus != network.ReachabilityPublic {
// Aggressively switch to public from other states ignoring confidence
log.Debugf("NAT status is public")
// we are flipping our NATStatus, so confidence drops to 0
as.confidence = 0
if as.service != nil {
as.service.Enable()
}
changed = true
} else if as.confidence < maxConfidence {
as.confidence++
}
as.status.Store(&observation)
if changed {
as.emitStatus()
}
} else if observation == network.ReachabilityPrivate {
if currentStatus != network.ReachabilityPrivate {
if as.confidence > 0 {
as.confidence--
} else {
log.Debugf("NAT status is private")
// we are flipping our NATStatus, so confidence drops to 0
as.confidence = 0
as.status.Store(&observation)
if as.service != nil {
as.service.Disable()
}
as.emitStatus()
}
} else if as.confidence < maxConfidence {
as.confidence++
as.status.Store(&observation)
}
} else if as.confidence > 0 {
// don't just flip to unknown, reduce confidence first
as.confidence--
} else {
log.Debugf("NAT status is unknown")
as.status.Store(&observation)
if currentStatus != network.ReachabilityUnknown {
if as.service != nil {
as.service.Enable()
}
as.emitStatus()
}
}
if as.metricsTracer != nil {
as.metricsTracer.ReachabilityStatusConfidence(as.confidence)
}
}
func (as *AmbientAutoNAT) tryProbe(p peer.ID) bool {
as.lastProbeTry = time.Now()
if p.Validate() != nil {
return false
}
if lastTime, ok := as.recentProbes[p]; ok {
if time.Since(lastTime) < as.throttlePeerPeriod {
return false
}
}
as.cleanupRecentProbes()
info := as.host.Peerstore().PeerInfo(p)
if !as.config.dialPolicy.skipPeer(info.Addrs) {
as.recentProbes[p] = time.Now()
as.lastProbe = time.Now()
go as.probe(&info)
return true
}
return false
}
func (as *AmbientAutoNAT) probe(pi *peer.AddrInfo) {
cli := NewAutoNATClient(as.host, as.config.addressFunc, as.metricsTracer)
ctx, cancel := context.WithTimeout(as.ctx, as.config.requestTimeout)
defer cancel()
err := cli.DialBack(ctx, pi.ID)
log.Debugf("Dialback through peer %s completed: err: %s", pi.ID, err)
select {
case as.dialResponses <- err:
case <-as.ctx.Done():
return
}
}
func (as *AmbientAutoNAT) getPeerToProbe() peer.ID {
peers := as.host.Network().Peers()
if len(peers) == 0 {
return ""
}
candidates := make([]peer.ID, 0, len(peers))
for _, p := range peers {
info := as.host.Peerstore().PeerInfo(p)
// Exclude peers which don't support the autonat protocol.
if proto, err := as.host.Peerstore().SupportsProtocols(p, AutoNATProto); len(proto) == 0 || err != nil {
continue
}
// Exclude peers in backoff.
if lastTime, ok := as.recentProbes[p]; ok {
if time.Since(lastTime) < as.throttlePeerPeriod {
continue
}
}
if as.config.dialPolicy.skipPeer(info.Addrs) {
continue
}
candidates = append(candidates, p)
}
if len(candidates) == 0 {
return ""
}
return candidates[rand.Intn(len(candidates))]
}
func (as *AmbientAutoNAT) Close() error {
as.ctxCancel()
if as.service != nil {
as.service.Disable()
}
<-as.backgroundRunning
return nil
}
// Status returns the AutoNAT observed reachability status.
func (s *StaticAutoNAT) Status() network.Reachability {
return s.reachability
}
func (s *StaticAutoNAT) Close() error {
if s.service != nil {
s.service.Disable()
}
return nil
}

View File

@@ -0,0 +1,122 @@
package autonat
import (
"context"
"fmt"
"time"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/host/autonat/pb"
"github.com/libp2p/go-msgio/pbio"
)
// NewAutoNATClient creates a fresh instance of an AutoNATClient
// If addrFunc is nil, h.Addrs will be used
func NewAutoNATClient(h host.Host, addrFunc AddrFunc, mt MetricsTracer) Client {
if addrFunc == nil {
addrFunc = h.Addrs
}
return &client{h: h, addrFunc: addrFunc, mt: mt}
}
type client struct {
h host.Host
addrFunc AddrFunc
mt MetricsTracer
}
// DialBack asks peer p to dial us back on all addresses returned by the addrFunc.
// It blocks until we've received a response from the peer.
//
// Note: A returned error Message_E_DIAL_ERROR does not imply that the server
// actually performed a dial attempt. Servers that run a version < v0.20.0 also
// return Message_E_DIAL_ERROR if the dial was skipped due to the dialPolicy.
func (c *client) DialBack(ctx context.Context, p peer.ID) error {
s, err := c.h.NewStream(ctx, p, AutoNATProto)
if err != nil {
return err
}
if err := s.Scope().SetService(ServiceName); err != nil {
log.Debugf("error attaching stream to autonat service: %s", err)
s.Reset()
return err
}
if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil {
log.Debugf("error reserving memory for autonat stream: %s", err)
s.Reset()
return err
}
defer s.Scope().ReleaseMemory(maxMsgSize)
s.SetDeadline(time.Now().Add(streamTimeout))
// Might as well just reset the stream. Once we get to this point, we
// don't care about being nice.
defer s.Close()
r := pbio.NewDelimitedReader(s, maxMsgSize)
w := pbio.NewDelimitedWriter(s)
req := newDialMessage(peer.AddrInfo{ID: c.h.ID(), Addrs: c.addrFunc()})
if err := w.WriteMsg(req); err != nil {
s.Reset()
return err
}
var res pb.Message
if err := r.ReadMsg(&res); err != nil {
s.Reset()
return err
}
if res.GetType() != pb.Message_DIAL_RESPONSE {
s.Reset()
return fmt.Errorf("unexpected response: %s", res.GetType().String())
}
status := res.GetDialResponse().GetStatus()
if c.mt != nil {
c.mt.ReceivedDialResponse(status)
}
switch status {
case pb.Message_OK:
return nil
default:
return Error{Status: status, Text: res.GetDialResponse().GetStatusText()}
}
}
// Error wraps errors signalled by AutoNAT services
type Error struct {
Status pb.Message_ResponseStatus
Text string
}
func (e Error) Error() string {
return fmt.Sprintf("AutoNAT error: %s (%s)", e.Text, e.Status.String())
}
// IsDialError returns true if the error was due to a dial back failure
func (e Error) IsDialError() bool {
return e.Status == pb.Message_E_DIAL_ERROR
}
// IsDialRefused returns true if the error was due to a refusal to dial back
func (e Error) IsDialRefused() bool {
return e.Status == pb.Message_E_DIAL_REFUSED
}
// IsDialError returns true if the AutoNAT peer signalled an error dialing back
func IsDialError(e error) bool {
ae, ok := e.(Error)
return ok && ae.IsDialError()
}
// IsDialRefused returns true if the AutoNAT peer signalled refusal to dial back
func IsDialRefused(e error) bool {
ae, ok := e.(Error)
return ok && ae.IsDialRefused()
}

View File

@@ -0,0 +1,95 @@
package autonat
import (
"net"
"github.com/libp2p/go-libp2p/core/host"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
type dialPolicy struct {
allowSelfDials bool
host host.Host
}
// skipDial indicates that a multiaddress isn't worth attempted dialing.
// The same logic is used when the autonat client is considering if
// a remote peer is worth using as a server, and when the server is
// considering if a requested client is worth dialing back.
func (d *dialPolicy) skipDial(addr ma.Multiaddr) bool {
// skip relay addresses
_, err := addr.ValueForProtocol(ma.P_CIRCUIT)
if err == nil {
return true
}
if d.allowSelfDials {
return false
}
// skip private network (unroutable) addresses
if !manet.IsPublicAddr(addr) {
return true
}
candidateIP, err := manet.ToIP(addr)
if err != nil {
return true
}
// Skip dialing addresses we believe are the local node's
for _, localAddr := range d.host.Addrs() {
localIP, err := manet.ToIP(localAddr)
if err != nil {
continue
}
if localIP.Equal(candidateIP) {
return true
}
}
return false
}
// skipPeer indicates that the collection of multiaddresses representing a peer
// isn't worth attempted dialing. If one of the addresses matches an address
// we believe is ours, we exclude the peer, even if there are other valid
// public addresses in the list.
func (d *dialPolicy) skipPeer(addrs []ma.Multiaddr) bool {
localAddrs := d.host.Addrs()
localHosts := make([]net.IP, 0)
for _, lAddr := range localAddrs {
if _, err := lAddr.ValueForProtocol(ma.P_CIRCUIT); err != nil && manet.IsPublicAddr(lAddr) {
lIP, err := manet.ToIP(lAddr)
if err != nil {
continue
}
localHosts = append(localHosts, lIP)
}
}
// if a public IP of the peer is one of ours: skip the peer.
goodPublic := false
for _, addr := range addrs {
if _, err := addr.ValueForProtocol(ma.P_CIRCUIT); err != nil && manet.IsPublicAddr(addr) {
aIP, err := manet.ToIP(addr)
if err != nil {
continue
}
for _, lIP := range localHosts {
if lIP.Equal(aIP) {
return true
}
}
goodPublic = true
}
}
if d.allowSelfDials {
return false
}
return !goodPublic
}

View File

@@ -0,0 +1,31 @@
package autonat
import (
"context"
"io"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
ma "github.com/multiformats/go-multiaddr"
)
// AutoNAT is the interface for NAT autodiscovery
type AutoNAT interface {
// Status returns the current NAT status
Status() network.Reachability
io.Closer
}
// Client is a stateless client interface to AutoNAT peers
type Client interface {
// DialBack requests from a peer providing AutoNAT services to test dial back
// and report the address on a successful connection.
DialBack(ctx context.Context, p peer.ID) error
}
// AddrFunc is a function returning the candidate addresses for the local host.
type AddrFunc func() []ma.Multiaddr
// Option is an Autonat option for configuration
type Option func(*config) error

View File

@@ -0,0 +1,162 @@
package autonat
import (
"time"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/p2p/host/autonat/pb"
"github.com/libp2p/go-libp2p/p2p/metricshelper"
"github.com/prometheus/client_golang/prometheus"
)
const metricNamespace = "libp2p_autonat"
var (
reachabilityStatus = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "reachability_status",
Help: "Current node reachability",
},
)
reachabilityStatusConfidence = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "reachability_status_confidence",
Help: "Node reachability status confidence",
},
)
receivedDialResponseTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "received_dial_response_total",
Help: "Count of dial responses for client",
},
[]string{"response_status"},
)
outgoingDialResponseTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "outgoing_dial_response_total",
Help: "Count of dial responses for server",
},
[]string{"response_status"},
)
outgoingDialRefusedTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "outgoing_dial_refused_total",
Help: "Count of dial requests refused by server",
},
[]string{"refusal_reason"},
)
nextProbeTimestamp = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "next_probe_timestamp",
Help: "Time of next probe",
},
)
collectors = []prometheus.Collector{
reachabilityStatus,
reachabilityStatusConfidence,
receivedDialResponseTotal,
outgoingDialResponseTotal,
outgoingDialRefusedTotal,
nextProbeTimestamp,
}
)
type MetricsTracer interface {
ReachabilityStatus(status network.Reachability)
ReachabilityStatusConfidence(confidence int)
ReceivedDialResponse(status pb.Message_ResponseStatus)
OutgoingDialResponse(status pb.Message_ResponseStatus)
OutgoingDialRefused(reason string)
NextProbeTime(t time.Time)
}
func getResponseStatus(status pb.Message_ResponseStatus) string {
var s string
switch status {
case pb.Message_OK:
s = "ok"
case pb.Message_E_DIAL_ERROR:
s = "dial error"
case pb.Message_E_DIAL_REFUSED:
s = "dial refused"
case pb.Message_E_BAD_REQUEST:
s = "bad request"
case pb.Message_E_INTERNAL_ERROR:
s = "internal error"
default:
s = "unknown"
}
return s
}
const (
rate_limited = "rate limited"
dial_blocked = "dial blocked"
no_valid_address = "no valid address"
)
type metricsTracer struct{}
var _ MetricsTracer = &metricsTracer{}
type metricsTracerSetting struct {
reg prometheus.Registerer
}
type MetricsTracerOption func(*metricsTracerSetting)
func WithRegisterer(reg prometheus.Registerer) MetricsTracerOption {
return func(s *metricsTracerSetting) {
if reg != nil {
s.reg = reg
}
}
}
func NewMetricsTracer(opts ...MetricsTracerOption) MetricsTracer {
setting := &metricsTracerSetting{reg: prometheus.DefaultRegisterer}
for _, opt := range opts {
opt(setting)
}
metricshelper.RegisterCollectors(setting.reg, collectors...)
return &metricsTracer{}
}
func (mt *metricsTracer) ReachabilityStatus(status network.Reachability) {
reachabilityStatus.Set(float64(status))
}
func (mt *metricsTracer) ReachabilityStatusConfidence(confidence int) {
reachabilityStatusConfidence.Set(float64(confidence))
}
func (mt *metricsTracer) ReceivedDialResponse(status pb.Message_ResponseStatus) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, getResponseStatus(status))
receivedDialResponseTotal.WithLabelValues(*tags...).Inc()
}
func (mt *metricsTracer) OutgoingDialResponse(status pb.Message_ResponseStatus) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, getResponseStatus(status))
outgoingDialResponseTotal.WithLabelValues(*tags...).Inc()
}
func (mt *metricsTracer) OutgoingDialRefused(reason string) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, reason)
outgoingDialRefusedTotal.WithLabelValues(*tags...).Inc()
}
func (mt *metricsTracer) NextProbeTime(t time.Time) {
nextProbeTimestamp.Set(float64(t.Unix()))
}

View File

@@ -0,0 +1,30 @@
package autonat
import (
"github.com/libp2p/go-libp2p/core/network"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
var _ network.Notifiee = (*AmbientAutoNAT)(nil)
// Listen is part of the network.Notifiee interface
func (as *AmbientAutoNAT) Listen(net network.Network, a ma.Multiaddr) {}
// ListenClose is part of the network.Notifiee interface
func (as *AmbientAutoNAT) ListenClose(net network.Network, a ma.Multiaddr) {}
// Connected is part of the network.Notifiee interface
func (as *AmbientAutoNAT) Connected(net network.Network, c network.Conn) {
if c.Stat().Direction == network.DirInbound &&
manet.IsPublicAddr(c.RemoteMultiaddr()) {
select {
case as.inboundConn <- c:
default:
}
}
}
// Disconnected is part of the network.Notifiee interface
func (as *AmbientAutoNAT) Disconnected(net network.Network, c network.Conn) {}

View File

@@ -0,0 +1,153 @@
package autonat
import (
"errors"
"time"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
)
// config holds configurable options for the autonat subsystem.
type config struct {
host host.Host
addressFunc AddrFunc
dialPolicy dialPolicy
dialer network.Network
forceReachability bool
reachability network.Reachability
metricsTracer MetricsTracer
// client
bootDelay time.Duration
retryInterval time.Duration
refreshInterval time.Duration
requestTimeout time.Duration
throttlePeerPeriod time.Duration
// server
dialTimeout time.Duration
maxPeerAddresses int
throttleGlobalMax int
throttlePeerMax int
throttleResetPeriod time.Duration
throttleResetJitter time.Duration
}
var defaults = func(c *config) error {
c.bootDelay = 15 * time.Second
c.retryInterval = 90 * time.Second
c.refreshInterval = 15 * time.Minute
c.requestTimeout = 30 * time.Second
c.throttlePeerPeriod = 90 * time.Second
c.dialTimeout = 15 * time.Second
c.maxPeerAddresses = 16
c.throttleGlobalMax = 30
c.throttlePeerMax = 3
c.throttleResetPeriod = 1 * time.Minute
c.throttleResetJitter = 15 * time.Second
return nil
}
// EnableService specifies that AutoNAT should be allowed to run a NAT service to help
// other peers determine their own NAT status. The provided Network should not be the
// default network/dialer of the host passed to `New`, as the NAT system will need to
// make parallel connections, and as such will modify both the associated peerstore
// and terminate connections of this dialer. The dialer provided
// should be compatible (TCP/UDP) however with the transports of the libp2p network.
func EnableService(dialer network.Network) Option {
return func(c *config) error {
if dialer == c.host.Network() || dialer.Peerstore() == c.host.Peerstore() {
return errors.New("dialer should not be that of the host")
}
c.dialer = dialer
return nil
}
}
// WithReachability overrides autonat to simply report an over-ridden reachability
// status.
func WithReachability(reachability network.Reachability) Option {
return func(c *config) error {
c.forceReachability = true
c.reachability = reachability
return nil
}
}
// UsingAddresses allows overriding which Addresses the AutoNAT client believes
// are "its own". Useful for testing, or for more exotic port-forwarding
// scenarios where the host may be listening on different ports than it wants
// to externally advertise or verify connectability on.
func UsingAddresses(addrFunc AddrFunc) Option {
return func(c *config) error {
if addrFunc == nil {
return errors.New("invalid address function supplied")
}
c.addressFunc = addrFunc
return nil
}
}
// WithSchedule configures how aggressively probes will be made to verify the
// address of the host. retryInterval indicates how often probes should be made
// when the host lacks confidence about its address, while refreshInterval
// is the schedule of periodic probes when the host believes it knows its
// steady-state reachability.
func WithSchedule(retryInterval, refreshInterval time.Duration) Option {
return func(c *config) error {
c.retryInterval = retryInterval
c.refreshInterval = refreshInterval
return nil
}
}
// WithoutStartupDelay removes the initial delay the NAT subsystem typically
// uses as a buffer for ensuring that connectivity and guesses as to the hosts
// local interfaces have settled down during startup.
func WithoutStartupDelay() Option {
return func(c *config) error {
c.bootDelay = 1
return nil
}
}
// WithoutThrottling indicates that this autonat service should not place
// restrictions on how many peers it is willing to help when acting as
// a server.
func WithoutThrottling() Option {
return func(c *config) error {
c.throttleGlobalMax = 0
return nil
}
}
// WithThrottling specifies how many peers (`amount`) it is willing to help
// ever `interval` amount of time when acting as a server.
func WithThrottling(amount int, interval time.Duration) Option {
return func(c *config) error {
c.throttleGlobalMax = amount
c.throttleResetPeriod = interval
c.throttleResetJitter = interval / 4
return nil
}
}
// WithPeerThrottling specifies a limit for the maximum number of IP checks
// this node will provide to an individual peer in each `interval`.
func WithPeerThrottling(amount int) Option {
return func(c *config) error {
c.throttlePeerMax = amount
return nil
}
}
// WithMetricsTracer uses mt to track autonat metrics
func WithMetricsTracer(mt MetricsTracer) Option {
return func(c *config) error {
c.metricsTracer = mt
return nil
}
}

View File

@@ -0,0 +1,524 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.30.0
// protoc v3.21.12
// source: pb/autonat.proto
package pb
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Message_MessageType int32
const (
Message_DIAL Message_MessageType = 0
Message_DIAL_RESPONSE Message_MessageType = 1
)
// Enum value maps for Message_MessageType.
var (
Message_MessageType_name = map[int32]string{
0: "DIAL",
1: "DIAL_RESPONSE",
}
Message_MessageType_value = map[string]int32{
"DIAL": 0,
"DIAL_RESPONSE": 1,
}
)
func (x Message_MessageType) Enum() *Message_MessageType {
p := new(Message_MessageType)
*p = x
return p
}
func (x Message_MessageType) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (Message_MessageType) Descriptor() protoreflect.EnumDescriptor {
return file_pb_autonat_proto_enumTypes[0].Descriptor()
}
func (Message_MessageType) Type() protoreflect.EnumType {
return &file_pb_autonat_proto_enumTypes[0]
}
func (x Message_MessageType) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Do not use.
func (x *Message_MessageType) UnmarshalJSON(b []byte) error {
num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b)
if err != nil {
return err
}
*x = Message_MessageType(num)
return nil
}
// Deprecated: Use Message_MessageType.Descriptor instead.
func (Message_MessageType) EnumDescriptor() ([]byte, []int) {
return file_pb_autonat_proto_rawDescGZIP(), []int{0, 0}
}
type Message_ResponseStatus int32
const (
Message_OK Message_ResponseStatus = 0
Message_E_DIAL_ERROR Message_ResponseStatus = 100
Message_E_DIAL_REFUSED Message_ResponseStatus = 101
Message_E_BAD_REQUEST Message_ResponseStatus = 200
Message_E_INTERNAL_ERROR Message_ResponseStatus = 300
)
// Enum value maps for Message_ResponseStatus.
var (
Message_ResponseStatus_name = map[int32]string{
0: "OK",
100: "E_DIAL_ERROR",
101: "E_DIAL_REFUSED",
200: "E_BAD_REQUEST",
300: "E_INTERNAL_ERROR",
}
Message_ResponseStatus_value = map[string]int32{
"OK": 0,
"E_DIAL_ERROR": 100,
"E_DIAL_REFUSED": 101,
"E_BAD_REQUEST": 200,
"E_INTERNAL_ERROR": 300,
}
)
func (x Message_ResponseStatus) Enum() *Message_ResponseStatus {
p := new(Message_ResponseStatus)
*p = x
return p
}
func (x Message_ResponseStatus) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (Message_ResponseStatus) Descriptor() protoreflect.EnumDescriptor {
return file_pb_autonat_proto_enumTypes[1].Descriptor()
}
func (Message_ResponseStatus) Type() protoreflect.EnumType {
return &file_pb_autonat_proto_enumTypes[1]
}
func (x Message_ResponseStatus) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Do not use.
func (x *Message_ResponseStatus) UnmarshalJSON(b []byte) error {
num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b)
if err != nil {
return err
}
*x = Message_ResponseStatus(num)
return nil
}
// Deprecated: Use Message_ResponseStatus.Descriptor instead.
func (Message_ResponseStatus) EnumDescriptor() ([]byte, []int) {
return file_pb_autonat_proto_rawDescGZIP(), []int{0, 1}
}
type Message struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Type *Message_MessageType `protobuf:"varint,1,opt,name=type,enum=autonat.pb.Message_MessageType" json:"type,omitempty"`
Dial *Message_Dial `protobuf:"bytes,2,opt,name=dial" json:"dial,omitempty"`
DialResponse *Message_DialResponse `protobuf:"bytes,3,opt,name=dialResponse" json:"dialResponse,omitempty"`
}
func (x *Message) Reset() {
*x = Message{}
if protoimpl.UnsafeEnabled {
mi := &file_pb_autonat_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Message) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Message) ProtoMessage() {}
func (x *Message) ProtoReflect() protoreflect.Message {
mi := &file_pb_autonat_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Message.ProtoReflect.Descriptor instead.
func (*Message) Descriptor() ([]byte, []int) {
return file_pb_autonat_proto_rawDescGZIP(), []int{0}
}
func (x *Message) GetType() Message_MessageType {
if x != nil && x.Type != nil {
return *x.Type
}
return Message_DIAL
}
func (x *Message) GetDial() *Message_Dial {
if x != nil {
return x.Dial
}
return nil
}
func (x *Message) GetDialResponse() *Message_DialResponse {
if x != nil {
return x.DialResponse
}
return nil
}
type Message_PeerInfo struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Id []byte `protobuf:"bytes,1,opt,name=id" json:"id,omitempty"`
Addrs [][]byte `protobuf:"bytes,2,rep,name=addrs" json:"addrs,omitempty"`
}
func (x *Message_PeerInfo) Reset() {
*x = Message_PeerInfo{}
if protoimpl.UnsafeEnabled {
mi := &file_pb_autonat_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Message_PeerInfo) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Message_PeerInfo) ProtoMessage() {}
func (x *Message_PeerInfo) ProtoReflect() protoreflect.Message {
mi := &file_pb_autonat_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Message_PeerInfo.ProtoReflect.Descriptor instead.
func (*Message_PeerInfo) Descriptor() ([]byte, []int) {
return file_pb_autonat_proto_rawDescGZIP(), []int{0, 0}
}
func (x *Message_PeerInfo) GetId() []byte {
if x != nil {
return x.Id
}
return nil
}
func (x *Message_PeerInfo) GetAddrs() [][]byte {
if x != nil {
return x.Addrs
}
return nil
}
type Message_Dial struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Peer *Message_PeerInfo `protobuf:"bytes,1,opt,name=peer" json:"peer,omitempty"`
}
func (x *Message_Dial) Reset() {
*x = Message_Dial{}
if protoimpl.UnsafeEnabled {
mi := &file_pb_autonat_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Message_Dial) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Message_Dial) ProtoMessage() {}
func (x *Message_Dial) ProtoReflect() protoreflect.Message {
mi := &file_pb_autonat_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Message_Dial.ProtoReflect.Descriptor instead.
func (*Message_Dial) Descriptor() ([]byte, []int) {
return file_pb_autonat_proto_rawDescGZIP(), []int{0, 1}
}
func (x *Message_Dial) GetPeer() *Message_PeerInfo {
if x != nil {
return x.Peer
}
return nil
}
type Message_DialResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Status *Message_ResponseStatus `protobuf:"varint,1,opt,name=status,enum=autonat.pb.Message_ResponseStatus" json:"status,omitempty"`
StatusText *string `protobuf:"bytes,2,opt,name=statusText" json:"statusText,omitempty"`
Addr []byte `protobuf:"bytes,3,opt,name=addr" json:"addr,omitempty"`
}
func (x *Message_DialResponse) Reset() {
*x = Message_DialResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_pb_autonat_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Message_DialResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Message_DialResponse) ProtoMessage() {}
func (x *Message_DialResponse) ProtoReflect() protoreflect.Message {
mi := &file_pb_autonat_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Message_DialResponse.ProtoReflect.Descriptor instead.
func (*Message_DialResponse) Descriptor() ([]byte, []int) {
return file_pb_autonat_proto_rawDescGZIP(), []int{0, 2}
}
func (x *Message_DialResponse) GetStatus() Message_ResponseStatus {
if x != nil && x.Status != nil {
return *x.Status
}
return Message_OK
}
func (x *Message_DialResponse) GetStatusText() string {
if x != nil && x.StatusText != nil {
return *x.StatusText
}
return ""
}
func (x *Message_DialResponse) GetAddr() []byte {
if x != nil {
return x.Addr
}
return nil
}
var File_pb_autonat_proto protoreflect.FileDescriptor
var file_pb_autonat_proto_rawDesc = []byte{
0x0a, 0x10, 0x70, 0x62, 0x2f, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x12, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x2e, 0x70, 0x62, 0x22, 0xb5,
0x04, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x33, 0x0a, 0x04, 0x74, 0x79,
0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e,
0x61, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x4d, 0x65,
0x73, 0x73, 0x61, 0x67, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12,
0x2c, 0x0a, 0x04, 0x64, 0x69, 0x61, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e,
0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61,
0x67, 0x65, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x04, 0x64, 0x69, 0x61, 0x6c, 0x12, 0x44, 0x0a,
0x0c, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x18, 0x03, 0x20,
0x01, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x2e, 0x70, 0x62,
0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x52, 0x0c, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x1a, 0x30, 0x0a, 0x08, 0x50, 0x65, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x12,
0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x02, 0x69, 0x64, 0x12,
0x14, 0x0a, 0x05, 0x61, 0x64, 0x64, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x05,
0x61, 0x64, 0x64, 0x72, 0x73, 0x1a, 0x38, 0x0a, 0x04, 0x44, 0x69, 0x61, 0x6c, 0x12, 0x30, 0x0a,
0x04, 0x70, 0x65, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x61, 0x75,
0x74, 0x6f, 0x6e, 0x61, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
0x2e, 0x50, 0x65, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x04, 0x70, 0x65, 0x65, 0x72, 0x1a,
0x7e, 0x0a, 0x0c, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12,
0x3a, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32,
0x22, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x2e, 0x70, 0x62, 0x2e, 0x4d, 0x65, 0x73,
0x73, 0x61, 0x67, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73,
0x74, 0x61, 0x74, 0x75, 0x73, 0x54, 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
0x0a, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x54, 0x65, 0x78, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x61,
0x64, 0x64, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x61, 0x64, 0x64, 0x72, 0x22,
0x2a, 0x0a, 0x0b, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x08,
0x0a, 0x04, 0x44, 0x49, 0x41, 0x4c, 0x10, 0x00, 0x12, 0x11, 0x0a, 0x0d, 0x44, 0x49, 0x41, 0x4c,
0x5f, 0x52, 0x45, 0x53, 0x50, 0x4f, 0x4e, 0x53, 0x45, 0x10, 0x01, 0x22, 0x69, 0x0a, 0x0e, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x06, 0x0a,
0x02, 0x4f, 0x4b, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x45, 0x5f, 0x44, 0x49, 0x41, 0x4c, 0x5f,
0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x64, 0x12, 0x12, 0x0a, 0x0e, 0x45, 0x5f, 0x44, 0x49, 0x41,
0x4c, 0x5f, 0x52, 0x45, 0x46, 0x55, 0x53, 0x45, 0x44, 0x10, 0x65, 0x12, 0x12, 0x0a, 0x0d, 0x45,
0x5f, 0x42, 0x41, 0x44, 0x5f, 0x52, 0x45, 0x51, 0x55, 0x45, 0x53, 0x54, 0x10, 0xc8, 0x01, 0x12,
0x15, 0x0a, 0x10, 0x45, 0x5f, 0x49, 0x4e, 0x54, 0x45, 0x52, 0x4e, 0x41, 0x4c, 0x5f, 0x45, 0x52,
0x52, 0x4f, 0x52, 0x10, 0xac, 0x02,
}
var (
file_pb_autonat_proto_rawDescOnce sync.Once
file_pb_autonat_proto_rawDescData = file_pb_autonat_proto_rawDesc
)
func file_pb_autonat_proto_rawDescGZIP() []byte {
file_pb_autonat_proto_rawDescOnce.Do(func() {
file_pb_autonat_proto_rawDescData = protoimpl.X.CompressGZIP(file_pb_autonat_proto_rawDescData)
})
return file_pb_autonat_proto_rawDescData
}
var file_pb_autonat_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
var file_pb_autonat_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
var file_pb_autonat_proto_goTypes = []interface{}{
(Message_MessageType)(0), // 0: autonat.pb.Message.MessageType
(Message_ResponseStatus)(0), // 1: autonat.pb.Message.ResponseStatus
(*Message)(nil), // 2: autonat.pb.Message
(*Message_PeerInfo)(nil), // 3: autonat.pb.Message.PeerInfo
(*Message_Dial)(nil), // 4: autonat.pb.Message.Dial
(*Message_DialResponse)(nil), // 5: autonat.pb.Message.DialResponse
}
var file_pb_autonat_proto_depIdxs = []int32{
0, // 0: autonat.pb.Message.type:type_name -> autonat.pb.Message.MessageType
4, // 1: autonat.pb.Message.dial:type_name -> autonat.pb.Message.Dial
5, // 2: autonat.pb.Message.dialResponse:type_name -> autonat.pb.Message.DialResponse
3, // 3: autonat.pb.Message.Dial.peer:type_name -> autonat.pb.Message.PeerInfo
1, // 4: autonat.pb.Message.DialResponse.status:type_name -> autonat.pb.Message.ResponseStatus
5, // [5:5] is the sub-list for method output_type
5, // [5:5] is the sub-list for method input_type
5, // [5:5] is the sub-list for extension type_name
5, // [5:5] is the sub-list for extension extendee
0, // [0:5] is the sub-list for field type_name
}
func init() { file_pb_autonat_proto_init() }
func file_pb_autonat_proto_init() {
if File_pb_autonat_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pb_autonat_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Message); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pb_autonat_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Message_PeerInfo); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pb_autonat_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Message_Dial); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pb_autonat_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Message_DialResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pb_autonat_proto_rawDesc,
NumEnums: 2,
NumMessages: 4,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_pb_autonat_proto_goTypes,
DependencyIndexes: file_pb_autonat_proto_depIdxs,
EnumInfos: file_pb_autonat_proto_enumTypes,
MessageInfos: file_pb_autonat_proto_msgTypes,
}.Build()
File_pb_autonat_proto = out.File
file_pb_autonat_proto_rawDesc = nil
file_pb_autonat_proto_goTypes = nil
file_pb_autonat_proto_depIdxs = nil
}

View File

@@ -0,0 +1,37 @@
syntax = "proto2";
package autonat.pb;
message Message {
enum MessageType {
DIAL = 0;
DIAL_RESPONSE = 1;
}
enum ResponseStatus {
OK = 0;
E_DIAL_ERROR = 100;
E_DIAL_REFUSED = 101;
E_BAD_REQUEST = 200;
E_INTERNAL_ERROR = 300;
}
message PeerInfo {
optional bytes id = 1;
repeated bytes addrs = 2;
}
message Dial {
optional PeerInfo peer = 1;
}
message DialResponse {
optional ResponseStatus status = 1;
optional string statusText = 2;
optional bytes addr = 3;
}
optional MessageType type = 1;
optional Dial dial = 2;
optional DialResponse dialResponse = 3;
}

View File

@@ -0,0 +1,41 @@
package autonat
import (
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/host/autonat/pb"
ma "github.com/multiformats/go-multiaddr"
)
//go:generate protoc --proto_path=$PWD:$PWD/../../.. --go_out=. --go_opt=Mpb/autonat.proto=./pb pb/autonat.proto
// AutoNATProto identifies the autonat service protocol
const AutoNATProto = "/libp2p/autonat/1.0.0"
func newDialMessage(pi peer.AddrInfo) *pb.Message {
msg := new(pb.Message)
msg.Type = pb.Message_DIAL.Enum()
msg.Dial = new(pb.Message_Dial)
msg.Dial.Peer = new(pb.Message_PeerInfo)
msg.Dial.Peer.Id = []byte(pi.ID)
msg.Dial.Peer.Addrs = make([][]byte, len(pi.Addrs))
for i, addr := range pi.Addrs {
msg.Dial.Peer.Addrs[i] = addr.Bytes()
}
return msg
}
func newDialResponseOK(addr ma.Multiaddr) *pb.Message_DialResponse {
dr := new(pb.Message_DialResponse)
dr.Status = pb.Message_OK.Enum()
dr.Addr = addr.Bytes()
return dr
}
func newDialResponseError(status pb.Message_ResponseStatus, text string) *pb.Message_DialResponse {
dr := new(pb.Message_DialResponse)
dr.Status = status.Enum()
dr.StatusText = &text
return dr
}

View File

@@ -0,0 +1,295 @@
package autonat
import (
"context"
"errors"
"math/rand"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/p2p/host/autonat/pb"
"github.com/libp2p/go-msgio/pbio"
ma "github.com/multiformats/go-multiaddr"
)
var streamTimeout = 60 * time.Second
const (
ServiceName = "libp2p.autonat"
maxMsgSize = 4096
)
// AutoNATService provides NAT autodetection services to other peers
type autoNATService struct {
instanceLock sync.Mutex
instance context.CancelFunc
backgroundRunning chan struct{} // closed when background exits
config *config
// rate limiter
mx sync.Mutex
reqs map[peer.ID]int
globalReqs int
}
// NewAutoNATService creates a new AutoNATService instance attached to a host
func newAutoNATService(c *config) (*autoNATService, error) {
if c.dialer == nil {
return nil, errors.New("cannot create NAT service without a network")
}
return &autoNATService{
config: c,
reqs: make(map[peer.ID]int),
}, nil
}
func (as *autoNATService) handleStream(s network.Stream) {
if err := s.Scope().SetService(ServiceName); err != nil {
log.Debugf("error attaching stream to autonat service: %s", err)
s.Reset()
return
}
if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil {
log.Debugf("error reserving memory for autonat stream: %s", err)
s.Reset()
return
}
defer s.Scope().ReleaseMemory(maxMsgSize)
s.SetDeadline(time.Now().Add(streamTimeout))
defer s.Close()
pid := s.Conn().RemotePeer()
log.Debugf("New stream from %s", pid.Pretty())
r := pbio.NewDelimitedReader(s, maxMsgSize)
w := pbio.NewDelimitedWriter(s)
var req pb.Message
var res pb.Message
err := r.ReadMsg(&req)
if err != nil {
log.Debugf("Error reading message from %s: %s", pid.Pretty(), err.Error())
s.Reset()
return
}
t := req.GetType()
if t != pb.Message_DIAL {
log.Debugf("Unexpected message from %s: %s (%d)", pid.Pretty(), t.String(), t)
s.Reset()
return
}
dr := as.handleDial(pid, s.Conn().RemoteMultiaddr(), req.GetDial().GetPeer())
res.Type = pb.Message_DIAL_RESPONSE.Enum()
res.DialResponse = dr
err = w.WriteMsg(&res)
if err != nil {
log.Debugf("Error writing response to %s: %s", pid.Pretty(), err.Error())
s.Reset()
return
}
if as.config.metricsTracer != nil {
as.config.metricsTracer.OutgoingDialResponse(res.GetDialResponse().GetStatus())
}
}
func (as *autoNATService) handleDial(p peer.ID, obsaddr ma.Multiaddr, mpi *pb.Message_PeerInfo) *pb.Message_DialResponse {
if mpi == nil {
return newDialResponseError(pb.Message_E_BAD_REQUEST, "missing peer info")
}
mpid := mpi.GetId()
if mpid != nil {
mp, err := peer.IDFromBytes(mpid)
if err != nil {
return newDialResponseError(pb.Message_E_BAD_REQUEST, "bad peer id")
}
if mp != p {
return newDialResponseError(pb.Message_E_BAD_REQUEST, "peer id mismatch")
}
}
addrs := make([]ma.Multiaddr, 0, as.config.maxPeerAddresses)
seen := make(map[string]struct{})
// Don't even try to dial peers with blocked remote addresses. In order to dial a peer, we
// need to know their public IP address, and it needs to be different from our public IP
// address.
if as.config.dialPolicy.skipDial(obsaddr) {
if as.config.metricsTracer != nil {
as.config.metricsTracer.OutgoingDialRefused(dial_blocked)
}
// Note: versions < v0.20.0 return Message_E_DIAL_ERROR here, thus we can not rely on this error code.
return newDialResponseError(pb.Message_E_DIAL_REFUSED, "refusing to dial peer with blocked observed address")
}
// Determine the peer's IP address.
hostIP, _ := ma.SplitFirst(obsaddr)
switch hostIP.Protocol().Code {
case ma.P_IP4, ma.P_IP6:
default:
// This shouldn't be possible as we should skip all addresses that don't include
// public IP addresses.
return newDialResponseError(pb.Message_E_INTERNAL_ERROR, "expected an IP address")
}
// add observed addr to the list of addresses to dial
addrs = append(addrs, obsaddr)
seen[obsaddr.String()] = struct{}{}
for _, maddr := range mpi.GetAddrs() {
addr, err := ma.NewMultiaddrBytes(maddr)
if err != nil {
log.Debugf("Error parsing multiaddr: %s", err.Error())
continue
}
// For security reasons, we _only_ dial the observed IP address.
// Replace other IP addresses with the observed one so we can still try the
// requested ports/transports.
if ip, rest := ma.SplitFirst(addr); !ip.Equal(hostIP) {
// Make sure it's an IP address
switch ip.Protocol().Code {
case ma.P_IP4, ma.P_IP6:
default:
continue
}
addr = hostIP
if rest != nil {
addr = addr.Encapsulate(rest)
}
}
// Make sure we're willing to dial the rest of the address (e.g., not a circuit
// address).
if as.config.dialPolicy.skipDial(addr) {
continue
}
str := addr.String()
_, ok := seen[str]
if ok {
continue
}
addrs = append(addrs, addr)
seen[str] = struct{}{}
if len(addrs) >= as.config.maxPeerAddresses {
break
}
}
if len(addrs) == 0 {
if as.config.metricsTracer != nil {
as.config.metricsTracer.OutgoingDialRefused(no_valid_address)
}
// Note: versions < v0.20.0 return Message_E_DIAL_ERROR here, thus we can not rely on this error code.
return newDialResponseError(pb.Message_E_DIAL_REFUSED, "no dialable addresses")
}
return as.doDial(peer.AddrInfo{ID: p, Addrs: addrs})
}
func (as *autoNATService) doDial(pi peer.AddrInfo) *pb.Message_DialResponse {
// rate limit check
as.mx.Lock()
count := as.reqs[pi.ID]
if count >= as.config.throttlePeerMax || (as.config.throttleGlobalMax > 0 &&
as.globalReqs >= as.config.throttleGlobalMax) {
as.mx.Unlock()
if as.config.metricsTracer != nil {
as.config.metricsTracer.OutgoingDialRefused(rate_limited)
}
return newDialResponseError(pb.Message_E_DIAL_REFUSED, "too many dials")
}
as.reqs[pi.ID] = count + 1
as.globalReqs++
as.mx.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), as.config.dialTimeout)
defer cancel()
as.config.dialer.Peerstore().ClearAddrs(pi.ID)
as.config.dialer.Peerstore().AddAddrs(pi.ID, pi.Addrs, peerstore.TempAddrTTL)
defer func() {
as.config.dialer.Peerstore().ClearAddrs(pi.ID)
as.config.dialer.Peerstore().RemovePeer(pi.ID)
}()
conn, err := as.config.dialer.DialPeer(ctx, pi.ID)
if err != nil {
log.Debugf("error dialing %s: %s", pi.ID.Pretty(), err.Error())
// wait for the context to timeout to avoid leaking timing information
// this renders the service ineffective as a port scanner
<-ctx.Done()
return newDialResponseError(pb.Message_E_DIAL_ERROR, "dial failed")
}
ra := conn.RemoteMultiaddr()
as.config.dialer.ClosePeer(pi.ID)
return newDialResponseOK(ra)
}
// Enable the autoNAT service if it is not running.
func (as *autoNATService) Enable() {
as.instanceLock.Lock()
defer as.instanceLock.Unlock()
if as.instance != nil {
return
}
ctx, cancel := context.WithCancel(context.Background())
as.instance = cancel
as.backgroundRunning = make(chan struct{})
as.config.host.SetStreamHandler(AutoNATProto, as.handleStream)
go as.background(ctx)
}
// Disable the autoNAT service if it is running.
func (as *autoNATService) Disable() {
as.instanceLock.Lock()
defer as.instanceLock.Unlock()
if as.instance != nil {
as.config.host.RemoveStreamHandler(AutoNATProto)
as.instance()
as.instance = nil
<-as.backgroundRunning
}
}
func (as *autoNATService) background(ctx context.Context) {
defer close(as.backgroundRunning)
timer := time.NewTimer(as.config.throttleResetPeriod)
defer timer.Stop()
for {
select {
case <-timer.C:
as.mx.Lock()
as.reqs = make(map[peer.ID]int)
as.globalReqs = 0
as.mx.Unlock()
jitter := rand.Float32() * float32(as.config.throttleResetJitter)
timer.Reset(as.config.throttleResetPeriod + time.Duration(int64(jitter)))
case <-ctx.Done():
return
}
}
}

View File

@@ -0,0 +1,165 @@
package autorelay
import (
"encoding/binary"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
// This function cleans up a relay's address set to remove private addresses and curtail
// addrsplosion.
func cleanupAddressSet(addrs []ma.Multiaddr) []ma.Multiaddr {
var public, private []ma.Multiaddr
for _, a := range addrs {
if isRelayAddr(a) {
continue
}
if manet.IsPublicAddr(a) || isDNSAddr(a) {
public = append(public, a)
continue
}
// discard unroutable addrs
if manet.IsPrivateAddr(a) {
private = append(private, a)
}
}
if !hasAddrsplosion(public) {
return public
}
return sanitizeAddrsplodedSet(public, private)
}
func isRelayAddr(a ma.Multiaddr) bool {
isRelay := false
ma.ForEach(a, func(c ma.Component) bool {
switch c.Protocol().Code {
case ma.P_CIRCUIT:
isRelay = true
return false
default:
return true
}
})
return isRelay
}
func isDNSAddr(a ma.Multiaddr) bool {
if first, _ := ma.SplitFirst(a); first != nil {
switch first.Protocol().Code {
case ma.P_DNS4, ma.P_DNS6, ma.P_DNSADDR:
return true
}
}
return false
}
// we have addrsplosion if for some protocol we advertise multiple ports on
// the same base address.
func hasAddrsplosion(addrs []ma.Multiaddr) bool {
aset := make(map[string]int)
for _, a := range addrs {
key, port := addrKeyAndPort(a)
xport, ok := aset[key]
if ok && port != xport {
return true
}
aset[key] = port
}
return false
}
func addrKeyAndPort(a ma.Multiaddr) (string, int) {
var (
key string
port int
)
ma.ForEach(a, func(c ma.Component) bool {
switch c.Protocol().Code {
case ma.P_TCP, ma.P_UDP:
port = int(binary.BigEndian.Uint16(c.RawValue()))
key += "/" + c.Protocol().Name
default:
val := c.Value()
if val == "" {
val = c.Protocol().Name
}
key += "/" + val
}
return true
})
return key, port
}
// clean up addrsplosion
// the following heuristic is used:
// - for each base address/protocol combination, if there are multiple ports advertised then
// only accept the default port if present.
// - If the default port is not present, we check for non-standard ports by tracking
// private port bindings if present.
// - If there is no default or private port binding, then we can't infer the correct
// port and give up and return all addrs (for that base address)
func sanitizeAddrsplodedSet(public, private []ma.Multiaddr) []ma.Multiaddr {
type portAndAddr struct {
addr ma.Multiaddr
port int
}
privports := make(map[int]struct{})
pubaddrs := make(map[string][]portAndAddr)
for _, a := range private {
_, port := addrKeyAndPort(a)
privports[port] = struct{}{}
}
for _, a := range public {
key, port := addrKeyAndPort(a)
pubaddrs[key] = append(pubaddrs[key], portAndAddr{addr: a, port: port})
}
var result []ma.Multiaddr
for _, pas := range pubaddrs {
if len(pas) == 1 {
// it's not addrsploded
result = append(result, pas[0].addr)
continue
}
haveAddr := false
for _, pa := range pas {
if _, ok := privports[pa.port]; ok {
// it matches a privately bound port, use it
result = append(result, pa.addr)
haveAddr = true
continue
}
if pa.port == 4001 || pa.port == 4002 {
// it's a default port, use it
result = append(result, pa.addr)
haveAddr = true
}
}
if !haveAddr {
// we weren't able to select a port; bite the bullet and use them all
for _, pa := range pas {
result = append(result, pa.addr)
}
}
}
return result
}

View File

@@ -0,0 +1,125 @@
package autorelay
import (
"context"
"errors"
"sync"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
basic "github.com/libp2p/go-libp2p/p2p/host/basic"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
)
var log = logging.Logger("autorelay")
type AutoRelay struct {
refCount sync.WaitGroup
ctx context.Context
ctxCancel context.CancelFunc
conf *config
mx sync.Mutex
status network.Reachability
relayFinder *relayFinder
host host.Host
addrsF basic.AddrsFactory
metricsTracer MetricsTracer
}
func NewAutoRelay(bhost *basic.BasicHost, opts ...Option) (*AutoRelay, error) {
r := &AutoRelay{
host: bhost,
addrsF: bhost.AddrsFactory,
status: network.ReachabilityUnknown,
}
conf := defaultConfig
for _, opt := range opts {
if err := opt(&conf); err != nil {
return nil, err
}
}
r.ctx, r.ctxCancel = context.WithCancel(context.Background())
r.conf = &conf
r.relayFinder = newRelayFinder(bhost, conf.peerSource, &conf)
r.metricsTracer = &wrappedMetricsTracer{conf.metricsTracer}
bhost.AddrsFactory = r.hostAddrs
return r, nil
}
func (r *AutoRelay) Start() {
r.refCount.Add(1)
go func() {
defer r.refCount.Done()
r.background()
}()
}
func (r *AutoRelay) background() {
subReachability, err := r.host.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("autorelay (background)"))
if err != nil {
log.Debug("failed to subscribe to the EvtLocalReachabilityChanged")
return
}
defer subReachability.Close()
for {
select {
case <-r.ctx.Done():
return
case ev, ok := <-subReachability.Out():
if !ok {
return
}
// TODO: push changed addresses
evt := ev.(event.EvtLocalReachabilityChanged)
switch evt.Reachability {
case network.ReachabilityPrivate, network.ReachabilityUnknown:
err := r.relayFinder.Start()
if errors.Is(err, errAlreadyRunning) {
log.Debug("tried to start already running relay finder")
} else if err != nil {
log.Errorw("failed to start relay finder", "error", err)
} else {
r.metricsTracer.RelayFinderStatus(true)
}
case network.ReachabilityPublic:
r.relayFinder.Stop()
r.metricsTracer.RelayFinderStatus(false)
}
r.mx.Lock()
r.status = evt.Reachability
r.mx.Unlock()
}
}
}
func (r *AutoRelay) hostAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
return r.relayAddrs(r.addrsF(addrs))
}
func (r *AutoRelay) relayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
r.mx.Lock()
defer r.mx.Unlock()
if r.status != network.ReachabilityPrivate {
return addrs
}
return r.relayFinder.relayAddrs(addrs)
}
func (r *AutoRelay) Close() error {
r.ctxCancel()
err := r.relayFinder.Stop()
r.refCount.Wait()
return err
}

View File

@@ -0,0 +1,23 @@
package autorelay
import (
"github.com/libp2p/go-libp2p/core/host"
)
type AutoRelayHost struct {
host.Host
ar *AutoRelay
}
func (h *AutoRelayHost) Close() error {
_ = h.ar.Close()
return h.Host.Close()
}
func (h *AutoRelayHost) Start() {
h.ar.Start()
}
func NewAutoRelayHost(h host.Host, ar *AutoRelay) *AutoRelayHost {
return &AutoRelayHost{Host: h, ar: ar}
}

View File

@@ -0,0 +1,373 @@
package autorelay
import (
"errors"
"github.com/libp2p/go-libp2p/p2p/metricshelper"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
pbv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/pb"
"github.com/prometheus/client_golang/prometheus"
)
const metricNamespace = "libp2p_autorelay"
var (
status = prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "status",
Help: "relay finder active",
})
reservationsOpenedTotal = prometheus.NewCounter(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "reservations_opened_total",
Help: "Reservations Opened",
},
)
reservationsClosedTotal = prometheus.NewCounter(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "reservations_closed_total",
Help: "Reservations Closed",
},
)
reservationRequestsOutcomeTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "reservation_requests_outcome_total",
Help: "Reservation Request Outcome",
},
[]string{"request_type", "outcome"},
)
relayAddressesUpdatedTotal = prometheus.NewCounter(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "relay_addresses_updated_total",
Help: "Relay Addresses Updated Count",
},
)
relayAddressesCount = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "relay_addresses_count",
Help: "Relay Addresses Count",
},
)
candidatesCircuitV2SupportTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "candidates_circuit_v2_support_total",
Help: "Candidiates supporting circuit v2",
},
[]string{"support"},
)
candidatesTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "candidates_total",
Help: "Candidates Total",
},
[]string{"type"},
)
candLoopState = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "candidate_loop_state",
Help: "Candidate Loop State",
},
)
scheduledWorkTime = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "scheduled_work_time",
Help: "Scheduled Work Times",
},
[]string{"work_type"},
)
desiredReservations = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "desired_reservations",
Help: "Desired Reservations",
},
)
collectors = []prometheus.Collector{
status,
reservationsOpenedTotal,
reservationsClosedTotal,
reservationRequestsOutcomeTotal,
relayAddressesUpdatedTotal,
relayAddressesCount,
candidatesCircuitV2SupportTotal,
candidatesTotal,
candLoopState,
scheduledWorkTime,
desiredReservations,
}
)
type candidateLoopState int
const (
peerSourceRateLimited candidateLoopState = iota
waitingOnPeerChan
waitingForTrigger
stopped
)
// MetricsTracer is the interface for tracking metrics for autorelay
type MetricsTracer interface {
RelayFinderStatus(isActive bool)
ReservationEnded(cnt int)
ReservationOpened(cnt int)
ReservationRequestFinished(isRefresh bool, err error)
RelayAddressCount(int)
RelayAddressUpdated()
CandidateChecked(supportsCircuitV2 bool)
CandidateAdded(cnt int)
CandidateRemoved(cnt int)
CandidateLoopState(state candidateLoopState)
ScheduledWorkUpdated(scheduledWork *scheduledWorkTimes)
DesiredReservations(int)
}
type metricsTracer struct{}
var _ MetricsTracer = &metricsTracer{}
type metricsTracerSetting struct {
reg prometheus.Registerer
}
type MetricsTracerOption func(*metricsTracerSetting)
func WithRegisterer(reg prometheus.Registerer) MetricsTracerOption {
return func(s *metricsTracerSetting) {
if reg != nil {
s.reg = reg
}
}
}
func NewMetricsTracer(opts ...MetricsTracerOption) MetricsTracer {
setting := &metricsTracerSetting{reg: prometheus.DefaultRegisterer}
for _, opt := range opts {
opt(setting)
}
metricshelper.RegisterCollectors(setting.reg, collectors...)
// Initialise these counters to 0 otherwise the first reservation requests aren't handled
// correctly when using promql increse function
reservationRequestsOutcomeTotal.WithLabelValues("refresh", "success")
reservationRequestsOutcomeTotal.WithLabelValues("new", "success")
candidatesCircuitV2SupportTotal.WithLabelValues("yes")
candidatesCircuitV2SupportTotal.WithLabelValues("no")
return &metricsTracer{}
}
func (mt *metricsTracer) RelayFinderStatus(isActive bool) {
if isActive {
status.Set(1)
} else {
status.Set(0)
}
}
func (mt *metricsTracer) ReservationEnded(cnt int) {
reservationsClosedTotal.Add(float64(cnt))
}
func (mt *metricsTracer) ReservationOpened(cnt int) {
reservationsOpenedTotal.Add(float64(cnt))
}
func (mt *metricsTracer) ReservationRequestFinished(isRefresh bool, err error) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
if isRefresh {
*tags = append(*tags, "refresh")
} else {
*tags = append(*tags, "new")
}
*tags = append(*tags, getReservationRequestStatus(err))
reservationRequestsOutcomeTotal.WithLabelValues(*tags...).Inc()
if !isRefresh && err == nil {
reservationsOpenedTotal.Inc()
}
}
func (mt *metricsTracer) RelayAddressUpdated() {
relayAddressesUpdatedTotal.Inc()
}
func (mt *metricsTracer) RelayAddressCount(cnt int) {
relayAddressesCount.Set(float64(cnt))
}
func (mt *metricsTracer) CandidateChecked(supportsCircuitV2 bool) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
if supportsCircuitV2 {
*tags = append(*tags, "yes")
} else {
*tags = append(*tags, "no")
}
candidatesCircuitV2SupportTotal.WithLabelValues(*tags...).Inc()
}
func (mt *metricsTracer) CandidateAdded(cnt int) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, "added")
candidatesTotal.WithLabelValues(*tags...).Add(float64(cnt))
}
func (mt *metricsTracer) CandidateRemoved(cnt int) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, "removed")
candidatesTotal.WithLabelValues(*tags...).Add(float64(cnt))
}
func (mt *metricsTracer) CandidateLoopState(state candidateLoopState) {
candLoopState.Set(float64(state))
}
func (mt *metricsTracer) ScheduledWorkUpdated(scheduledWork *scheduledWorkTimes) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, "allowed peer source call")
scheduledWorkTime.WithLabelValues(*tags...).Set(float64(scheduledWork.nextAllowedCallToPeerSource.Unix()))
*tags = (*tags)[:0]
*tags = append(*tags, "reservation refresh")
scheduledWorkTime.WithLabelValues(*tags...).Set(float64(scheduledWork.nextRefresh.Unix()))
*tags = (*tags)[:0]
*tags = append(*tags, "clear backoff")
scheduledWorkTime.WithLabelValues(*tags...).Set(float64(scheduledWork.nextBackoff.Unix()))
*tags = (*tags)[:0]
*tags = append(*tags, "old candidate check")
scheduledWorkTime.WithLabelValues(*tags...).Set(float64(scheduledWork.nextOldCandidateCheck.Unix()))
}
func (mt *metricsTracer) DesiredReservations(cnt int) {
desiredReservations.Set(float64(cnt))
}
func getReservationRequestStatus(err error) string {
if err == nil {
return "success"
}
status := "err other"
var re client.ReservationError
if errors.As(err, &re) {
switch re.Status {
case pbv2.Status_CONNECTION_FAILED:
return "connection failed"
case pbv2.Status_MALFORMED_MESSAGE:
return "malformed message"
case pbv2.Status_RESERVATION_REFUSED:
return "reservation refused"
case pbv2.Status_PERMISSION_DENIED:
return "permission denied"
case pbv2.Status_RESOURCE_LIMIT_EXCEEDED:
return "resource limit exceeded"
}
}
return status
}
// wrappedMetricsTracer wraps MetricsTracer and ignores all calls when mt is nil
type wrappedMetricsTracer struct {
mt MetricsTracer
}
var _ MetricsTracer = &wrappedMetricsTracer{}
func (mt *wrappedMetricsTracer) RelayFinderStatus(isActive bool) {
if mt.mt != nil {
mt.mt.RelayFinderStatus(isActive)
}
}
func (mt *wrappedMetricsTracer) ReservationEnded(cnt int) {
if mt.mt != nil {
mt.mt.ReservationEnded(cnt)
}
}
func (mt *wrappedMetricsTracer) ReservationOpened(cnt int) {
if mt.mt != nil {
mt.mt.ReservationOpened(cnt)
}
}
func (mt *wrappedMetricsTracer) ReservationRequestFinished(isRefresh bool, err error) {
if mt.mt != nil {
mt.mt.ReservationRequestFinished(isRefresh, err)
}
}
func (mt *wrappedMetricsTracer) RelayAddressUpdated() {
if mt.mt != nil {
mt.mt.RelayAddressUpdated()
}
}
func (mt *wrappedMetricsTracer) RelayAddressCount(cnt int) {
if mt.mt != nil {
mt.mt.RelayAddressCount(cnt)
}
}
func (mt *wrappedMetricsTracer) CandidateChecked(supportsCircuitV2 bool) {
if mt.mt != nil {
mt.mt.CandidateChecked(supportsCircuitV2)
}
}
func (mt *wrappedMetricsTracer) CandidateAdded(cnt int) {
if mt.mt != nil {
mt.mt.CandidateAdded(cnt)
}
}
func (mt *wrappedMetricsTracer) CandidateRemoved(cnt int) {
if mt.mt != nil {
mt.mt.CandidateRemoved(cnt)
}
}
func (mt *wrappedMetricsTracer) ScheduledWorkUpdated(scheduledWork *scheduledWorkTimes) {
if mt.mt != nil {
mt.mt.ScheduledWorkUpdated(scheduledWork)
}
}
func (mt *wrappedMetricsTracer) DesiredReservations(cnt int) {
if mt.mt != nil {
mt.mt.DesiredReservations(cnt)
}
}
func (mt *wrappedMetricsTracer) CandidateLoopState(state candidateLoopState) {
if mt.mt != nil {
mt.mt.CandidateLoopState(state)
}
}

View File

@@ -0,0 +1,233 @@
package autorelay
import (
"context"
"errors"
"time"
"github.com/libp2p/go-libp2p/core/peer"
)
// AutoRelay will call this function when it needs new candidates because it is
// not connected to the desired number of relays or we get disconnected from one
// of the relays. Implementations must send *at most* numPeers, and close the
// channel when they don't intend to provide any more peers. AutoRelay will not
// call the callback again until the channel is closed. Implementations should
// send new peers, but may send peers they sent before. AutoRelay implements a
// per-peer backoff (see WithBackoff). See WithMinInterval for setting the
// minimum interval between calls to the callback. The context.Context passed
// may be canceled when AutoRelay feels satisfied, it will be canceled when the
// node is shutting down. If the context is canceled you MUST close the output
// channel at some point.
type PeerSource func(ctx context.Context, num int) <-chan peer.AddrInfo
type config struct {
clock ClockWithInstantTimer
peerSource PeerSource
// minimum interval used to call the peerSource callback
minInterval time.Duration
// see WithMinCandidates
minCandidates int
// see WithMaxCandidates
maxCandidates int
// Delay until we obtain reservations with relays, if we have less than minCandidates candidates.
// See WithBootDelay.
bootDelay time.Duration
// backoff is the time we wait after failing to obtain a reservation with a candidate
backoff time.Duration
// Number of relays we strive to obtain a reservation with.
desiredRelays int
// see WithMaxCandidateAge
maxCandidateAge time.Duration
setMinCandidates bool
// see WithMetricsTracer
metricsTracer MetricsTracer
}
var defaultConfig = config{
clock: RealClock{},
minCandidates: 4,
maxCandidates: 20,
bootDelay: 3 * time.Minute,
backoff: time.Hour,
desiredRelays: 2,
maxCandidateAge: 30 * time.Minute,
minInterval: 30 * time.Second,
}
var (
errAlreadyHavePeerSource = errors.New("can only use a single WithPeerSource or WithStaticRelays")
)
type Option func(*config) error
func WithStaticRelays(static []peer.AddrInfo) Option {
return func(c *config) error {
if c.peerSource != nil {
return errAlreadyHavePeerSource
}
WithPeerSource(func(ctx context.Context, numPeers int) <-chan peer.AddrInfo {
if len(static) < numPeers {
numPeers = len(static)
}
c := make(chan peer.AddrInfo, numPeers)
defer close(c)
for i := 0; i < numPeers; i++ {
c <- static[i]
}
return c
})(c)
WithMinCandidates(len(static))(c)
WithMaxCandidates(len(static))(c)
WithNumRelays(len(static))(c)
return nil
}
}
// WithPeerSource defines a callback for AutoRelay to query for more relay candidates.
func WithPeerSource(f PeerSource) Option {
return func(c *config) error {
if c.peerSource != nil {
return errAlreadyHavePeerSource
}
c.peerSource = f
return nil
}
}
// WithNumRelays sets the number of relays we strive to obtain reservations with.
func WithNumRelays(n int) Option {
return func(c *config) error {
c.desiredRelays = n
return nil
}
}
// WithMaxCandidates sets the number of relay candidates that we buffer.
func WithMaxCandidates(n int) Option {
return func(c *config) error {
c.maxCandidates = n
if c.minCandidates > n {
c.minCandidates = n
}
return nil
}
}
// WithMinCandidates sets the minimum number of relay candidates we collect before to get a reservation
// with any of them (unless we've been running for longer than the boot delay).
// This is to make sure that we don't just randomly connect to the first candidate that we discover.
func WithMinCandidates(n int) Option {
return func(c *config) error {
if n > c.maxCandidates {
n = c.maxCandidates
}
c.minCandidates = n
c.setMinCandidates = true
return nil
}
}
// WithBootDelay set the boot delay for finding relays.
// We won't attempt any reservation if we've have less than a minimum number of candidates.
// This prevents us to connect to the "first best" relay, and allows us to carefully select the relay.
// However, in case we haven't found enough relays after the boot delay, we use what we have.
func WithBootDelay(d time.Duration) Option {
return func(c *config) error {
c.bootDelay = d
return nil
}
}
// WithBackoff sets the time we wait after failing to obtain a reservation with a candidate.
func WithBackoff(d time.Duration) Option {
return func(c *config) error {
c.backoff = d
return nil
}
}
// WithMaxCandidateAge sets the maximum age of a candidate.
// When we are connected to the desired number of relays, we don't ask the peer source for new candidates.
// This can lead to AutoRelay's candidate list becoming outdated, and means we won't be able
// to quickly establish a new relay connection if our existing connection breaks, if all the candidates
// have become stale.
func WithMaxCandidateAge(d time.Duration) Option {
return func(c *config) error {
c.maxCandidateAge = d
return nil
}
}
// InstantTimer is a timer that triggers at some instant rather than some duration
type InstantTimer interface {
Reset(d time.Time) bool
Stop() bool
Ch() <-chan time.Time
}
// ClockWithInstantTimer is a clock that can create timers that trigger at some
// instant rather than some duration
type ClockWithInstantTimer interface {
Now() time.Time
Since(t time.Time) time.Duration
InstantTimer(when time.Time) InstantTimer
}
type RealTimer struct{ t *time.Timer }
var _ InstantTimer = (*RealTimer)(nil)
func (t RealTimer) Ch() <-chan time.Time {
return t.t.C
}
func (t RealTimer) Reset(d time.Time) bool {
return t.t.Reset(time.Until(d))
}
func (t RealTimer) Stop() bool {
return t.t.Stop()
}
type RealClock struct{}
var _ ClockWithInstantTimer = RealClock{}
func (RealClock) Now() time.Time {
return time.Now()
}
func (RealClock) Since(t time.Time) time.Duration {
return time.Since(t)
}
func (RealClock) InstantTimer(when time.Time) InstantTimer {
t := time.NewTimer(time.Until(when))
return &RealTimer{t}
}
func WithClock(cl ClockWithInstantTimer) Option {
return func(c *config) error {
c.clock = cl
return nil
}
}
// WithMinInterval sets the minimum interval after which peerSource callback will be called for more
// candidates even if AutoRelay needs new candidates.
func WithMinInterval(interval time.Duration) Option {
return func(c *config) error {
c.minInterval = interval
return nil
}
}
// WithMetricsTracer configures autorelay to use mt to track metrics
func WithMetricsTracer(mt MetricsTracer) Option {
return func(c *config) error {
c.metricsTracer = mt
return nil
}
}

View File

@@ -0,0 +1,17 @@
package autorelay
import (
ma "github.com/multiformats/go-multiaddr"
)
// Filter filters out all relay addresses.
func Filter(addrs []ma.Multiaddr) []ma.Multiaddr {
raddrs := make([]ma.Multiaddr, 0, len(addrs))
for _, addr := range addrs {
if isRelayAddr(addr) {
continue
}
raddrs = append(raddrs, addr)
}
return raddrs
}

View File

@@ -0,0 +1,810 @@
package autorelay
import (
"context"
"errors"
"fmt"
"math/rand"
"sync"
"time"
"golang.org/x/sync/errgroup"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
basic "github.com/libp2p/go-libp2p/p2p/host/basic"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
circuitv2_proto "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
const protoIDv2 = circuitv2_proto.ProtoIDv2Hop
// Terminology:
// Candidate: Once we connect to a node and it supports relay protocol,
// we call it a candidate, and consider using it as a relay.
// Relay: Out of the list of candidates, we select a relay to connect to.
// Currently, we just randomly select a candidate, but we can employ more sophisticated
// selection strategies here (e.g. by facotring in the RTT).
const (
rsvpRefreshInterval = time.Minute
rsvpExpirationSlack = 2 * time.Minute
autorelayTag = "autorelay"
)
type candidate struct {
added time.Time
supportsRelayV2 bool
ai peer.AddrInfo
}
// relayFinder is a Host that uses relays for connectivity when a NAT is detected.
type relayFinder struct {
bootTime time.Time
host *basic.BasicHost
conf *config
refCount sync.WaitGroup
ctxCancel context.CancelFunc
ctxCancelMx sync.Mutex
peerSource PeerSource
candidateFound chan struct{} // receives every time we find a new relay candidate
candidateMx sync.Mutex
candidates map[peer.ID]*candidate
backoff map[peer.ID]time.Time
maybeConnectToRelayTrigger chan struct{} // cap: 1
// Any time _something_ hapens that might cause us to need new candidates.
// This could be
// * the disconnection of a relay
// * the failed attempt to obtain a reservation with a current candidate
// * a candidate is deleted due to its age
maybeRequestNewCandidates chan struct{} // cap: 1.
relayUpdated chan struct{}
relayMx sync.Mutex
relays map[peer.ID]*circuitv2.Reservation
cachedAddrs []ma.Multiaddr
cachedAddrsExpiry time.Time
// A channel that triggers a run of `runScheduledWork`.
triggerRunScheduledWork chan struct{}
metricsTracer MetricsTracer
}
var errAlreadyRunning = errors.New("relayFinder already running")
func newRelayFinder(host *basic.BasicHost, peerSource PeerSource, conf *config) *relayFinder {
if peerSource == nil {
panic("Can not create a new relayFinder. Need a Peer Source fn or a list of static relays. Refer to the documentation around `libp2p.EnableAutoRelay`")
}
return &relayFinder{
bootTime: conf.clock.Now(),
host: host,
conf: conf,
peerSource: peerSource,
candidates: make(map[peer.ID]*candidate),
backoff: make(map[peer.ID]time.Time),
candidateFound: make(chan struct{}, 1),
maybeConnectToRelayTrigger: make(chan struct{}, 1),
maybeRequestNewCandidates: make(chan struct{}, 1),
triggerRunScheduledWork: make(chan struct{}, 1),
relays: make(map[peer.ID]*circuitv2.Reservation),
relayUpdated: make(chan struct{}, 1),
metricsTracer: &wrappedMetricsTracer{conf.metricsTracer},
}
}
type scheduledWorkTimes struct {
leastFrequentInterval time.Duration
nextRefresh time.Time
nextBackoff time.Time
nextOldCandidateCheck time.Time
nextAllowedCallToPeerSource time.Time
}
func (rf *relayFinder) background(ctx context.Context) {
peerSourceRateLimiter := make(chan struct{}, 1)
rf.refCount.Add(1)
go func() {
defer rf.refCount.Done()
rf.findNodes(ctx, peerSourceRateLimiter)
}()
rf.refCount.Add(1)
go func() {
defer rf.refCount.Done()
rf.handleNewCandidates(ctx)
}()
subConnectedness, err := rf.host.EventBus().Subscribe(new(event.EvtPeerConnectednessChanged), eventbus.Name("autorelay (relay finder)"))
if err != nil {
log.Error("failed to subscribe to the EvtPeerConnectednessChanged")
return
}
defer subConnectedness.Close()
now := rf.conf.clock.Now()
bootDelayTimer := rf.conf.clock.InstantTimer(now.Add(rf.conf.bootDelay))
defer bootDelayTimer.Stop()
// This is the least frequent event. It's our fallback timer if we don't have any other work to do.
leastFrequentInterval := rf.conf.minInterval
// Check if leastFrequentInterval is 0 to avoid busy looping
if rf.conf.backoff > leastFrequentInterval || leastFrequentInterval == 0 {
leastFrequentInterval = rf.conf.backoff
}
if rf.conf.maxCandidateAge > leastFrequentInterval || leastFrequentInterval == 0 {
leastFrequentInterval = rf.conf.maxCandidateAge
}
if rsvpRefreshInterval > leastFrequentInterval || leastFrequentInterval == 0 {
leastFrequentInterval = rsvpRefreshInterval
}
scheduledWork := &scheduledWorkTimes{
leastFrequentInterval: leastFrequentInterval,
nextRefresh: now.Add(rsvpRefreshInterval),
nextBackoff: now.Add(rf.conf.backoff),
nextOldCandidateCheck: now.Add(rf.conf.maxCandidateAge),
nextAllowedCallToPeerSource: now.Add(-time.Second), // allow immediately
}
workTimer := rf.conf.clock.InstantTimer(rf.runScheduledWork(ctx, now, scheduledWork, peerSourceRateLimiter))
defer workTimer.Stop()
for {
select {
case ev, ok := <-subConnectedness.Out():
if !ok {
return
}
evt := ev.(event.EvtPeerConnectednessChanged)
if evt.Connectedness != network.NotConnected {
continue
}
push := false
rf.relayMx.Lock()
if rf.usingRelay(evt.Peer) { // we were disconnected from a relay
log.Debugw("disconnected from relay", "id", evt.Peer)
delete(rf.relays, evt.Peer)
rf.notifyMaybeConnectToRelay()
rf.notifyMaybeNeedNewCandidates()
push = true
}
rf.relayMx.Unlock()
if push {
rf.clearCachedAddrsAndSignalAddressChange()
rf.metricsTracer.ReservationEnded(1)
}
case <-rf.candidateFound:
rf.notifyMaybeConnectToRelay()
case <-bootDelayTimer.Ch():
rf.notifyMaybeConnectToRelay()
case <-rf.relayUpdated:
rf.clearCachedAddrsAndSignalAddressChange()
case now := <-workTimer.Ch():
// Note: `now` is not guaranteed to be the current time. It's the time
// that the timer was fired. This is okay because we'll schedule
// future work at a specific time.
nextTime := rf.runScheduledWork(ctx, now, scheduledWork, peerSourceRateLimiter)
workTimer.Reset(nextTime)
case <-rf.triggerRunScheduledWork:
// Ignore the next time because we aren't scheduling any future work here
_ = rf.runScheduledWork(ctx, rf.conf.clock.Now(), scheduledWork, peerSourceRateLimiter)
case <-ctx.Done():
return
}
}
}
func (rf *relayFinder) clearCachedAddrsAndSignalAddressChange() {
rf.relayMx.Lock()
rf.cachedAddrs = nil
rf.relayMx.Unlock()
rf.host.SignalAddressChange()
rf.metricsTracer.RelayAddressUpdated()
}
func (rf *relayFinder) runScheduledWork(ctx context.Context, now time.Time, scheduledWork *scheduledWorkTimes, peerSourceRateLimiter chan<- struct{}) time.Time {
nextTime := now.Add(scheduledWork.leastFrequentInterval)
if now.After(scheduledWork.nextRefresh) {
scheduledWork.nextRefresh = now.Add(rsvpRefreshInterval)
if rf.refreshReservations(ctx, now) {
rf.clearCachedAddrsAndSignalAddressChange()
}
}
if now.After(scheduledWork.nextBackoff) {
scheduledWork.nextBackoff = rf.clearBackoff(now)
}
if now.After(scheduledWork.nextOldCandidateCheck) {
scheduledWork.nextOldCandidateCheck = rf.clearOldCandidates(now)
}
if now.After(scheduledWork.nextAllowedCallToPeerSource) {
select {
case peerSourceRateLimiter <- struct{}{}:
scheduledWork.nextAllowedCallToPeerSource = now.Add(rf.conf.minInterval)
if scheduledWork.nextAllowedCallToPeerSource.Before(nextTime) {
nextTime = scheduledWork.nextAllowedCallToPeerSource
}
default:
}
} else {
// We still need to schedule this work if it's sooner than nextTime
if scheduledWork.nextAllowedCallToPeerSource.Before(nextTime) {
nextTime = scheduledWork.nextAllowedCallToPeerSource
}
}
// Find the next time we need to run scheduled work.
if scheduledWork.nextRefresh.Before(nextTime) {
nextTime = scheduledWork.nextRefresh
}
if scheduledWork.nextBackoff.Before(nextTime) {
nextTime = scheduledWork.nextBackoff
}
if scheduledWork.nextOldCandidateCheck.Before(nextTime) {
nextTime = scheduledWork.nextOldCandidateCheck
}
if nextTime == now {
// Only happens in CI with a mock clock
nextTime = nextTime.Add(1) // avoids an infinite loop
}
rf.metricsTracer.ScheduledWorkUpdated(scheduledWork)
return nextTime
}
// clearOldCandidates clears old candidates from the map. Returns the next time
// to run this function.
func (rf *relayFinder) clearOldCandidates(now time.Time) time.Time {
// If we don't have any candidates, we should run this again in rf.conf.maxCandidateAge.
nextTime := now.Add(rf.conf.maxCandidateAge)
var deleted bool
rf.candidateMx.Lock()
defer rf.candidateMx.Unlock()
for id, cand := range rf.candidates {
expiry := cand.added.Add(rf.conf.maxCandidateAge)
if expiry.After(now) {
if expiry.Before(nextTime) {
nextTime = expiry
}
} else {
log.Debugw("deleting candidate due to age", "id", id)
deleted = true
rf.removeCandidate(id)
}
}
if deleted {
rf.notifyMaybeNeedNewCandidates()
}
return nextTime
}
// clearBackoff clears old backoff entries from the map. Returns the next time
// to run this function.
func (rf *relayFinder) clearBackoff(now time.Time) time.Time {
nextTime := now.Add(rf.conf.backoff)
rf.candidateMx.Lock()
defer rf.candidateMx.Unlock()
for id, t := range rf.backoff {
expiry := t.Add(rf.conf.backoff)
if expiry.After(now) {
if expiry.Before(nextTime) {
nextTime = expiry
}
} else {
log.Debugw("removing backoff for node", "id", id)
delete(rf.backoff, id)
}
}
return nextTime
}
// findNodes accepts nodes from the channel and tests if they support relaying.
// It is run on both public and private nodes.
// It garbage collects old entries, so that nodes doesn't overflow.
// This makes sure that as soon as we need to find relay candidates, we have them available.
// peerSourceRateLimiter is used to limit how often we call the peer source.
func (rf *relayFinder) findNodes(ctx context.Context, peerSourceRateLimiter <-chan struct{}) {
var peerChan <-chan peer.AddrInfo
var wg sync.WaitGroup
for {
rf.candidateMx.Lock()
numCandidates := len(rf.candidates)
rf.candidateMx.Unlock()
if peerChan == nil && numCandidates < rf.conf.minCandidates {
rf.metricsTracer.CandidateLoopState(peerSourceRateLimited)
select {
case <-peerSourceRateLimiter:
peerChan = rf.peerSource(ctx, rf.conf.maxCandidates)
select {
case rf.triggerRunScheduledWork <- struct{}{}:
default:
}
case <-ctx.Done():
return
}
}
if peerChan == nil {
rf.metricsTracer.CandidateLoopState(waitingForTrigger)
} else {
rf.metricsTracer.CandidateLoopState(waitingOnPeerChan)
}
select {
case <-rf.maybeRequestNewCandidates:
continue
case pi, ok := <-peerChan:
if !ok {
wg.Wait()
peerChan = nil
continue
}
log.Debugw("found node", "id", pi.ID)
rf.candidateMx.Lock()
numCandidates := len(rf.candidates)
backoffStart, isOnBackoff := rf.backoff[pi.ID]
rf.candidateMx.Unlock()
if isOnBackoff {
log.Debugw("skipping node that we recently failed to obtain a reservation with", "id", pi.ID, "last attempt", rf.conf.clock.Since(backoffStart))
continue
}
if numCandidates >= rf.conf.maxCandidates {
log.Debugw("skipping node. Already have enough candidates", "id", pi.ID, "num", numCandidates, "max", rf.conf.maxCandidates)
continue
}
rf.refCount.Add(1)
wg.Add(1)
go func() {
defer rf.refCount.Done()
defer wg.Done()
if added := rf.handleNewNode(ctx, pi); added {
rf.notifyNewCandidate()
}
}()
case <-ctx.Done():
rf.metricsTracer.CandidateLoopState(stopped)
return
}
}
}
func (rf *relayFinder) notifyMaybeConnectToRelay() {
select {
case rf.maybeConnectToRelayTrigger <- struct{}{}:
default:
}
}
func (rf *relayFinder) notifyMaybeNeedNewCandidates() {
select {
case rf.maybeRequestNewCandidates <- struct{}{}:
default:
}
}
func (rf *relayFinder) notifyNewCandidate() {
select {
case rf.candidateFound <- struct{}{}:
default:
}
}
// handleNewNode tests if a peer supports circuit v2.
// This method is only run on private nodes.
// If a peer does, it is added to the candidates map.
// Note that just supporting the protocol doesn't guarantee that we can also obtain a reservation.
func (rf *relayFinder) handleNewNode(ctx context.Context, pi peer.AddrInfo) (added bool) {
rf.relayMx.Lock()
relayInUse := rf.usingRelay(pi.ID)
rf.relayMx.Unlock()
if relayInUse {
return false
}
ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
defer cancel()
supportsV2, err := rf.tryNode(ctx, pi)
if err != nil {
log.Debugf("node %s not accepted as a candidate: %s", pi.ID, err)
if err == errProtocolNotSupported {
rf.metricsTracer.CandidateChecked(false)
}
return false
}
rf.metricsTracer.CandidateChecked(true)
rf.candidateMx.Lock()
if len(rf.candidates) > rf.conf.maxCandidates {
rf.candidateMx.Unlock()
return false
}
log.Debugw("node supports relay protocol", "peer", pi.ID, "supports circuit v2", supportsV2)
rf.addCandidate(&candidate{
added: rf.conf.clock.Now(),
ai: pi,
supportsRelayV2: supportsV2,
})
rf.candidateMx.Unlock()
return true
}
var errProtocolNotSupported = errors.New("doesn't speak circuit v2")
// tryNode checks if a peer actually supports either circuit v2.
// It does not modify any internal state.
func (rf *relayFinder) tryNode(ctx context.Context, pi peer.AddrInfo) (supportsRelayV2 bool, err error) {
if err := rf.host.Connect(ctx, pi); err != nil {
return false, fmt.Errorf("error connecting to relay %s: %w", pi.ID, err)
}
conns := rf.host.Network().ConnsToPeer(pi.ID)
for _, conn := range conns {
if isRelayAddr(conn.RemoteMultiaddr()) {
return false, errors.New("not a public node")
}
}
// wait for identify to complete in at least one conn so that we can check the supported protocols
ready := make(chan struct{}, 1)
for _, conn := range conns {
go func(conn network.Conn) {
select {
case <-rf.host.IDService().IdentifyWait(conn):
select {
case ready <- struct{}{}:
default:
}
case <-ctx.Done():
}
}(conn)
}
select {
case <-ready:
case <-ctx.Done():
return false, ctx.Err()
}
protos, err := rf.host.Peerstore().SupportsProtocols(pi.ID, protoIDv2)
if err != nil {
return false, fmt.Errorf("error checking relay protocol support for peer %s: %w", pi.ID, err)
}
if len(protos) == 0 {
return false, errProtocolNotSupported
}
return true, nil
}
// When a new node that could be a relay is found, we receive a notification on the maybeConnectToRelayTrigger chan.
// This function makes sure that we only run one instance of maybeConnectToRelay at once, and buffers
// exactly one more trigger event to run maybeConnectToRelay.
func (rf *relayFinder) handleNewCandidates(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-rf.maybeConnectToRelayTrigger:
rf.maybeConnectToRelay(ctx)
}
}
}
func (rf *relayFinder) maybeConnectToRelay(ctx context.Context) {
rf.relayMx.Lock()
numRelays := len(rf.relays)
rf.relayMx.Unlock()
// We're already connected to our desired number of relays. Nothing to do here.
if numRelays == rf.conf.desiredRelays {
return
}
rf.candidateMx.Lock()
if len(rf.relays) == 0 && len(rf.candidates) < rf.conf.minCandidates && rf.conf.clock.Since(rf.bootTime) < rf.conf.bootDelay {
// During the startup phase, we don't want to connect to the first candidate that we find.
// Instead, we wait until we've found at least minCandidates, and then select the best of those.
// However, if that takes too long (longer than bootDelay), we still go ahead.
rf.candidateMx.Unlock()
return
}
if len(rf.candidates) == 0 {
rf.candidateMx.Unlock()
return
}
candidates := rf.selectCandidates()
rf.candidateMx.Unlock()
// We now iterate over the candidates, attempting (sequentially) to get reservations with them, until
// we reach the desired number of relays.
for _, cand := range candidates {
id := cand.ai.ID
rf.relayMx.Lock()
usingRelay := rf.usingRelay(id)
rf.relayMx.Unlock()
if usingRelay {
rf.candidateMx.Lock()
rf.removeCandidate(id)
rf.candidateMx.Unlock()
rf.notifyMaybeNeedNewCandidates()
continue
}
rsvp, err := rf.connectToRelay(ctx, cand)
if err != nil {
log.Debugw("failed to connect to relay", "peer", id, "error", err)
rf.notifyMaybeNeedNewCandidates()
rf.metricsTracer.ReservationRequestFinished(false, err)
continue
}
log.Debugw("adding new relay", "id", id)
rf.relayMx.Lock()
rf.relays[id] = rsvp
numRelays := len(rf.relays)
rf.relayMx.Unlock()
rf.notifyMaybeNeedNewCandidates()
rf.host.ConnManager().Protect(id, autorelayTag) // protect the connection
select {
case rf.relayUpdated <- struct{}{}:
default:
}
rf.metricsTracer.ReservationRequestFinished(false, nil)
if numRelays >= rf.conf.desiredRelays {
break
}
}
}
func (rf *relayFinder) connectToRelay(ctx context.Context, cand *candidate) (*circuitv2.Reservation, error) {
id := cand.ai.ID
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
var rsvp *circuitv2.Reservation
// make sure we're still connected.
if rf.host.Network().Connectedness(id) != network.Connected {
if err := rf.host.Connect(ctx, cand.ai); err != nil {
rf.candidateMx.Lock()
rf.removeCandidate(cand.ai.ID)
rf.candidateMx.Unlock()
return nil, fmt.Errorf("failed to connect: %w", err)
}
}
rf.candidateMx.Lock()
rf.backoff[id] = rf.conf.clock.Now()
rf.candidateMx.Unlock()
var err error
if cand.supportsRelayV2 {
rsvp, err = circuitv2.Reserve(ctx, rf.host, cand.ai)
if err != nil {
err = fmt.Errorf("failed to reserve slot: %w", err)
}
}
rf.candidateMx.Lock()
rf.removeCandidate(id)
rf.candidateMx.Unlock()
return rsvp, err
}
func (rf *relayFinder) refreshReservations(ctx context.Context, now time.Time) bool {
rf.relayMx.Lock()
// find reservations about to expire and refresh them in parallel
g := new(errgroup.Group)
for p, rsvp := range rf.relays {
if now.Add(rsvpExpirationSlack).Before(rsvp.Expiration) {
continue
}
p := p
g.Go(func() error {
err := rf.refreshRelayReservation(ctx, p)
rf.metricsTracer.ReservationRequestFinished(true, err)
return err
})
}
rf.relayMx.Unlock()
err := g.Wait()
return err != nil
}
func (rf *relayFinder) refreshRelayReservation(ctx context.Context, p peer.ID) error {
rsvp, err := circuitv2.Reserve(ctx, rf.host, peer.AddrInfo{ID: p})
rf.relayMx.Lock()
if err != nil {
log.Debugw("failed to refresh relay slot reservation", "relay", p, "error", err)
_, exists := rf.relays[p]
delete(rf.relays, p)
// unprotect the connection
rf.host.ConnManager().Unprotect(p, autorelayTag)
rf.relayMx.Unlock()
if exists {
rf.metricsTracer.ReservationEnded(1)
}
return err
}
log.Debugw("refreshed relay slot reservation", "relay", p)
rf.relays[p] = rsvp
rf.relayMx.Unlock()
return nil
}
// usingRelay returns if we're currently using the given relay.
func (rf *relayFinder) usingRelay(p peer.ID) bool {
_, ok := rf.relays[p]
return ok
}
// addCandidates adds a candidate to the candidates set. Assumes caller holds candidateMx mutex
func (rf *relayFinder) addCandidate(cand *candidate) {
_, exists := rf.candidates[cand.ai.ID]
rf.candidates[cand.ai.ID] = cand
if !exists {
rf.metricsTracer.CandidateAdded(1)
}
}
func (rf *relayFinder) removeCandidate(id peer.ID) {
_, exists := rf.candidates[id]
if exists {
delete(rf.candidates, id)
rf.metricsTracer.CandidateRemoved(1)
}
}
// selectCandidates returns an ordered slice of relay candidates.
// Callers should attempt to obtain reservations with the candidates in this order.
func (rf *relayFinder) selectCandidates() []*candidate {
now := rf.conf.clock.Now()
candidates := make([]*candidate, 0, len(rf.candidates))
for _, cand := range rf.candidates {
if cand.added.Add(rf.conf.maxCandidateAge).After(now) {
candidates = append(candidates, cand)
}
}
// TODO: better relay selection strategy; this just selects random relays,
// but we should probably use ping latency as the selection metric
rand.Shuffle(len(candidates), func(i, j int) {
candidates[i], candidates[j] = candidates[j], candidates[i]
})
return candidates
}
// This function is computes the NATed relay addrs when our status is private:
// - The public addrs are removed from the address set.
// - The non-public addrs are included verbatim so that peers behind the same NAT/firewall
// can still dial us directly.
// - On top of those, we add the relay-specific addrs for the relays to which we are
// connected. For each non-private relay addr, we encapsulate the p2p-circuit addr
// through which we can be dialed.
func (rf *relayFinder) relayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
rf.relayMx.Lock()
defer rf.relayMx.Unlock()
if rf.cachedAddrs != nil && rf.conf.clock.Now().Before(rf.cachedAddrsExpiry) {
return rf.cachedAddrs
}
raddrs := make([]ma.Multiaddr, 0, 4*len(rf.relays)+4)
// only keep private addrs from the original addr set
for _, addr := range addrs {
if manet.IsPrivateAddr(addr) {
raddrs = append(raddrs, addr)
}
}
// add relay specific addrs to the list
relayAddrCnt := 0
for p := range rf.relays {
addrs := cleanupAddressSet(rf.host.Peerstore().Addrs(p))
relayAddrCnt += len(addrs)
circuit := ma.StringCast(fmt.Sprintf("/p2p/%s/p2p-circuit", p.Pretty()))
for _, addr := range addrs {
pub := addr.Encapsulate(circuit)
raddrs = append(raddrs, pub)
}
}
rf.cachedAddrs = raddrs
rf.cachedAddrsExpiry = rf.conf.clock.Now().Add(30 * time.Second)
rf.metricsTracer.RelayAddressCount(relayAddrCnt)
return raddrs
}
func (rf *relayFinder) Start() error {
rf.ctxCancelMx.Lock()
defer rf.ctxCancelMx.Unlock()
if rf.ctxCancel != nil {
return errAlreadyRunning
}
log.Debug("starting relay finder")
rf.initMetrics()
ctx, cancel := context.WithCancel(context.Background())
rf.ctxCancel = cancel
rf.refCount.Add(1)
go func() {
defer rf.refCount.Done()
rf.background(ctx)
}()
return nil
}
func (rf *relayFinder) Stop() error {
rf.ctxCancelMx.Lock()
defer rf.ctxCancelMx.Unlock()
log.Debug("stopping relay finder")
if rf.ctxCancel != nil {
rf.ctxCancel()
}
rf.refCount.Wait()
rf.ctxCancel = nil
rf.resetMetrics()
return nil
}
func (rf *relayFinder) initMetrics() {
rf.metricsTracer.DesiredReservations(rf.conf.desiredRelays)
rf.relayMx.Lock()
rf.metricsTracer.ReservationOpened(len(rf.relays))
rf.relayMx.Unlock()
rf.candidateMx.Lock()
rf.metricsTracer.CandidateAdded(len(rf.candidates))
rf.candidateMx.Unlock()
}
func (rf *relayFinder) resetMetrics() {
rf.relayMx.Lock()
rf.metricsTracer.ReservationEnded(len(rf.relays))
rf.relayMx.Unlock()
rf.candidateMx.Lock()
rf.metricsTracer.CandidateRemoved(len(rf.candidates))
rf.candidateMx.Unlock()
rf.metricsTracer.RelayAddressCount(0)
rf.metricsTracer.ScheduledWorkUpdated(&scheduledWorkTimes{})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
//go:build gomock || generate
package basichost
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package basichost -destination mock_nat_test.go github.com/libp2p/go-libp2p/p2p/host/basic NAT"
type NAT nat

View File

@@ -0,0 +1,299 @@
package basichost
import (
"context"
"io"
"net"
"net/netip"
"strconv"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/network"
inat "github.com/libp2p/go-libp2p/p2p/net/nat"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
// NATManager is a simple interface to manage NAT devices.
// It listens Listen and ListenClose notifications from the network.Network,
// and tries to obtain port mappings for those.
type NATManager interface {
GetMapping(ma.Multiaddr) ma.Multiaddr
HasDiscoveredNAT() bool
io.Closer
}
// NewNATManager creates a NAT manager.
func NewNATManager(net network.Network) NATManager {
return newNATManager(net)
}
type entry struct {
protocol string
port int
}
type nat interface {
AddMapping(ctx context.Context, protocol string, port int) error
RemoveMapping(ctx context.Context, protocol string, port int) error
GetMapping(protocol string, port int) (netip.AddrPort, bool)
io.Closer
}
// so we can mock it in tests
var discoverNAT = func(ctx context.Context) (nat, error) { return inat.DiscoverNAT(ctx) }
// natManager takes care of adding + removing port mappings to the nat.
// Initialized with the host if it has a NATPortMap option enabled.
// natManager receives signals from the network, and check on nat mappings:
// - natManager listens to the network and adds or closes port mappings
// as the network signals Listen() or ListenClose().
// - closing the natManager closes the nat and its mappings.
type natManager struct {
net network.Network
natMx sync.RWMutex
nat nat
syncFlag chan struct{} // cap: 1
tracked map[entry]bool // the bool is only used in doSync and has no meaning outside of that function
refCount sync.WaitGroup
ctx context.Context
ctxCancel context.CancelFunc
}
func newNATManager(net network.Network) *natManager {
ctx, cancel := context.WithCancel(context.Background())
nmgr := &natManager{
net: net,
syncFlag: make(chan struct{}, 1),
ctx: ctx,
ctxCancel: cancel,
tracked: make(map[entry]bool),
}
nmgr.refCount.Add(1)
go nmgr.background(ctx)
return nmgr
}
// Close closes the natManager, closing the underlying nat
// and unregistering from network events.
func (nmgr *natManager) Close() error {
nmgr.ctxCancel()
nmgr.refCount.Wait()
return nil
}
func (nmgr *natManager) HasDiscoveredNAT() bool {
nmgr.natMx.RLock()
defer nmgr.natMx.RUnlock()
return nmgr.nat != nil
}
func (nmgr *natManager) background(ctx context.Context) {
defer nmgr.refCount.Done()
defer func() {
nmgr.natMx.Lock()
defer nmgr.natMx.Unlock()
if nmgr.nat != nil {
nmgr.nat.Close()
}
}()
discoverCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
natInstance, err := discoverNAT(discoverCtx)
if err != nil {
log.Info("DiscoverNAT error:", err)
return
}
nmgr.natMx.Lock()
nmgr.nat = natInstance
nmgr.natMx.Unlock()
// sign natManager up for network notifications
// we need to sign up here to avoid missing some notifs
// before the NAT has been found.
nmgr.net.Notify((*nmgrNetNotifiee)(nmgr))
defer nmgr.net.StopNotify((*nmgrNetNotifiee)(nmgr))
nmgr.doSync() // sync one first.
for {
select {
case <-nmgr.syncFlag:
nmgr.doSync() // sync when our listen addresses chnage.
case <-ctx.Done():
return
}
}
}
func (nmgr *natManager) sync() {
select {
case nmgr.syncFlag <- struct{}{}:
default:
}
}
// doSync syncs the current NAT mappings, removing any outdated mappings and adding any
// new mappings.
func (nmgr *natManager) doSync() {
for e := range nmgr.tracked {
nmgr.tracked[e] = false
}
var newAddresses []entry
for _, maddr := range nmgr.net.ListenAddresses() {
// Strip the IP
maIP, rest := ma.SplitFirst(maddr)
if maIP == nil || rest == nil {
continue
}
switch maIP.Protocol().Code {
case ma.P_IP6, ma.P_IP4:
default:
continue
}
// Only bother if we're listening on an unicast / unspecified IP.
ip := net.IP(maIP.RawValue())
if !ip.IsGlobalUnicast() && !ip.IsUnspecified() {
continue
}
// Extract the port/protocol
proto, _ := ma.SplitFirst(rest)
if proto == nil {
continue
}
var protocol string
switch proto.Protocol().Code {
case ma.P_TCP:
protocol = "tcp"
case ma.P_UDP:
protocol = "udp"
default:
continue
}
port, err := strconv.ParseUint(proto.Value(), 10, 16)
if err != nil {
// bug in multiaddr
panic(err)
}
e := entry{protocol: protocol, port: int(port)}
if _, ok := nmgr.tracked[e]; ok {
nmgr.tracked[e] = true
} else {
newAddresses = append(newAddresses, e)
}
}
var wg sync.WaitGroup
defer wg.Wait()
// Close old mappings
for e, v := range nmgr.tracked {
if !v {
nmgr.nat.RemoveMapping(nmgr.ctx, e.protocol, e.port)
delete(nmgr.tracked, e)
}
}
// Create new mappings.
for _, e := range newAddresses {
if err := nmgr.nat.AddMapping(nmgr.ctx, e.protocol, e.port); err != nil {
log.Errorf("failed to port-map %s port %d: %s", e.protocol, e.port, err)
}
nmgr.tracked[e] = false
}
}
func (nmgr *natManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr {
nmgr.natMx.Lock()
defer nmgr.natMx.Unlock()
if nmgr.nat == nil { // NAT not yet initialized
return nil
}
var found bool
var proto int // ma.P_TCP or ma.P_UDP
transport, rest := ma.SplitFunc(addr, func(c ma.Component) bool {
if found {
return true
}
proto = c.Protocol().Code
found = proto == ma.P_TCP || proto == ma.P_UDP
return false
})
if !manet.IsThinWaist(transport) {
return nil
}
naddr, err := manet.ToNetAddr(transport)
if err != nil {
log.Error("error parsing net multiaddr %q: %s", transport, err)
return nil
}
var (
ip net.IP
port int
protocol string
)
switch naddr := naddr.(type) {
case *net.TCPAddr:
ip = naddr.IP
port = naddr.Port
protocol = "tcp"
case *net.UDPAddr:
ip = naddr.IP
port = naddr.Port
protocol = "udp"
default:
return nil
}
if !ip.IsGlobalUnicast() && !ip.IsUnspecified() {
// We only map global unicast & unspecified addresses ports, not broadcast, multicast, etc.
return nil
}
extAddr, ok := nmgr.nat.GetMapping(protocol, port)
if !ok {
return nil
}
var mappedAddr net.Addr
switch naddr.(type) {
case *net.TCPAddr:
mappedAddr = net.TCPAddrFromAddrPort(extAddr)
case *net.UDPAddr:
mappedAddr = net.UDPAddrFromAddrPort(extAddr)
}
mappedMaddr, err := manet.FromNetAddr(mappedAddr)
if err != nil {
log.Errorf("mapped addr can't be turned into a multiaddr %q: %s", mappedAddr, err)
return nil
}
extMaddr := mappedMaddr
if rest != nil {
extMaddr = ma.Join(extMaddr, rest)
}
return extMaddr
}
type nmgrNetNotifiee natManager
func (nn *nmgrNetNotifiee) natManager() *natManager { return (*natManager)(nn) }
func (nn *nmgrNetNotifiee) Listen(network.Network, ma.Multiaddr) { nn.natManager().sync() }
func (nn *nmgrNetNotifiee) ListenClose(n network.Network, addr ma.Multiaddr) { nn.natManager().sync() }
func (nn *nmgrNetNotifiee) Connected(network.Network, network.Conn) {}
func (nn *nmgrNetNotifiee) Disconnected(network.Network, network.Conn) {}

View File

@@ -0,0 +1,232 @@
package blankhost
import (
"context"
"errors"
"fmt"
"io"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/record"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
mstream "github.com/multiformats/go-multistream"
)
var log = logging.Logger("blankhost")
// BlankHost is the thinnest implementation of the host.Host interface
type BlankHost struct {
n network.Network
mux *mstream.MultistreamMuxer[protocol.ID]
cmgr connmgr.ConnManager
eventbus event.Bus
emitters struct {
evtLocalProtocolsUpdated event.Emitter
}
}
type config struct {
cmgr connmgr.ConnManager
eventBus event.Bus
}
type Option = func(cfg *config)
func WithConnectionManager(cmgr connmgr.ConnManager) Option {
return func(cfg *config) {
cfg.cmgr = cmgr
}
}
func WithEventBus(eventBus event.Bus) Option {
return func(cfg *config) {
cfg.eventBus = eventBus
}
}
func NewBlankHost(n network.Network, options ...Option) *BlankHost {
cfg := config{
cmgr: &connmgr.NullConnMgr{},
}
for _, opt := range options {
opt(&cfg)
}
bh := &BlankHost{
n: n,
cmgr: cfg.cmgr,
mux: mstream.NewMultistreamMuxer[protocol.ID](),
}
if bh.eventbus == nil {
bh.eventbus = eventbus.NewBus(eventbus.WithMetricsTracer(eventbus.NewMetricsTracer()))
}
// subscribe the connection manager to network notifications (has no effect with NullConnMgr)
n.Notify(bh.cmgr.Notifee())
var err error
if bh.emitters.evtLocalProtocolsUpdated, err = bh.eventbus.Emitter(&event.EvtLocalProtocolsUpdated{}); err != nil {
return nil
}
n.SetStreamHandler(bh.newStreamHandler)
// persist a signed peer record for self to the peerstore.
if err := bh.initSignedRecord(); err != nil {
log.Errorf("error creating blank host, err=%s", err)
return nil
}
return bh
}
func (bh *BlankHost) initSignedRecord() error {
cab, ok := peerstore.GetCertifiedAddrBook(bh.n.Peerstore())
if !ok {
log.Error("peerstore does not support signed records")
return errors.New("peerstore does not support signed records")
}
rec := peer.PeerRecordFromAddrInfo(peer.AddrInfo{ID: bh.ID(), Addrs: bh.Addrs()})
ev, err := record.Seal(rec, bh.Peerstore().PrivKey(bh.ID()))
if err != nil {
log.Errorf("failed to create signed record for self, err=%s", err)
return fmt.Errorf("failed to create signed record for self, err=%s", err)
}
_, err = cab.ConsumePeerRecord(ev, peerstore.PermanentAddrTTL)
if err != nil {
log.Errorf("failed to persist signed record to peerstore,err=%s", err)
return fmt.Errorf("failed to persist signed record for self, err=%s", err)
}
return err
}
var _ host.Host = (*BlankHost)(nil)
func (bh *BlankHost) Addrs() []ma.Multiaddr {
addrs, err := bh.n.InterfaceListenAddresses()
if err != nil {
log.Debug("error retrieving network interface addrs: ", err)
return nil
}
return addrs
}
func (bh *BlankHost) Close() error {
return bh.n.Close()
}
func (bh *BlankHost) Connect(ctx context.Context, ai peer.AddrInfo) error {
// absorb addresses into peerstore
bh.Peerstore().AddAddrs(ai.ID, ai.Addrs, peerstore.TempAddrTTL)
cs := bh.n.ConnsToPeer(ai.ID)
if len(cs) > 0 {
return nil
}
_, err := bh.Network().DialPeer(ctx, ai.ID)
if err != nil {
return fmt.Errorf("failed to dial: %w", err)
}
return err
}
func (bh *BlankHost) Peerstore() peerstore.Peerstore {
return bh.n.Peerstore()
}
func (bh *BlankHost) ID() peer.ID {
return bh.n.LocalPeer()
}
func (bh *BlankHost) NewStream(ctx context.Context, p peer.ID, protos ...protocol.ID) (network.Stream, error) {
s, err := bh.n.NewStream(ctx, p)
if err != nil {
return nil, fmt.Errorf("failed to open stream: %w", err)
}
selected, err := mstream.SelectOneOf(protos, s)
if err != nil {
s.Reset()
return nil, fmt.Errorf("failed to negotiate protocol: %w", err)
}
s.SetProtocol(selected)
bh.Peerstore().AddProtocols(p, selected)
return s, nil
}
func (bh *BlankHost) RemoveStreamHandler(pid protocol.ID) {
bh.Mux().RemoveHandler(pid)
bh.emitters.evtLocalProtocolsUpdated.Emit(event.EvtLocalProtocolsUpdated{
Removed: []protocol.ID{pid},
})
}
func (bh *BlankHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) {
bh.Mux().AddHandler(pid, func(p protocol.ID, rwc io.ReadWriteCloser) error {
is := rwc.(network.Stream)
is.SetProtocol(p)
handler(is)
return nil
})
bh.emitters.evtLocalProtocolsUpdated.Emit(event.EvtLocalProtocolsUpdated{
Added: []protocol.ID{pid},
})
}
func (bh *BlankHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) bool, handler network.StreamHandler) {
bh.Mux().AddHandlerWithFunc(pid, m, func(p protocol.ID, rwc io.ReadWriteCloser) error {
is := rwc.(network.Stream)
is.SetProtocol(p)
handler(is)
return nil
})
bh.emitters.evtLocalProtocolsUpdated.Emit(event.EvtLocalProtocolsUpdated{
Added: []protocol.ID{pid},
})
}
// newStreamHandler is the remote-opened stream handler for network.Network
func (bh *BlankHost) newStreamHandler(s network.Stream) {
protoID, handle, err := bh.Mux().Negotiate(s)
if err != nil {
log.Infow("protocol negotiation failed", "error", err)
s.Reset()
return
}
s.SetProtocol(protoID)
go handle(protoID, s)
}
// TODO: i'm not sure this really needs to be here
func (bh *BlankHost) Mux() protocol.Switch {
return bh.mux
}
// TODO: also not sure this fits... Might be better ways around this (leaky abstractions)
func (bh *BlankHost) Network() network.Network {
return bh.n
}
func (bh *BlankHost) ConnManager() connmgr.ConnManager {
return bh.cmgr
}
func (bh *BlankHost) EventBus() event.Bus {
return bh.eventbus
}

View File

@@ -0,0 +1,418 @@
package eventbus
import (
"errors"
"fmt"
"reflect"
"sync"
"sync/atomic"
"github.com/libp2p/go-libp2p/core/event"
)
// /////////////////////
// BUS
// basicBus is a type-based event delivery system
type basicBus struct {
lk sync.RWMutex
nodes map[reflect.Type]*node
wildcard *wildcardNode
metricsTracer MetricsTracer
}
var _ event.Bus = (*basicBus)(nil)
type emitter struct {
n *node
w *wildcardNode
typ reflect.Type
closed atomic.Bool
dropper func(reflect.Type)
metricsTracer MetricsTracer
}
func (e *emitter) Emit(evt interface{}) error {
if e.closed.Load() {
return fmt.Errorf("emitter is closed")
}
e.n.emit(evt)
e.w.emit(evt)
if e.metricsTracer != nil {
e.metricsTracer.EventEmitted(e.typ)
}
return nil
}
func (e *emitter) Close() error {
if !e.closed.CompareAndSwap(false, true) {
return fmt.Errorf("closed an emitter more than once")
}
if e.n.nEmitters.Add(-1) == 0 {
e.dropper(e.typ)
}
return nil
}
func NewBus(opts ...Option) event.Bus {
bus := &basicBus{
nodes: map[reflect.Type]*node{},
wildcard: &wildcardNode{},
}
for _, opt := range opts {
opt(bus)
}
return bus
}
func (b *basicBus) withNode(typ reflect.Type, cb func(*node), async func(*node)) {
b.lk.Lock()
n, ok := b.nodes[typ]
if !ok {
n = newNode(typ, b.metricsTracer)
b.nodes[typ] = n
}
n.lk.Lock()
b.lk.Unlock()
cb(n)
if async == nil {
n.lk.Unlock()
} else {
go func() {
defer n.lk.Unlock()
async(n)
}()
}
}
func (b *basicBus) tryDropNode(typ reflect.Type) {
b.lk.Lock()
n, ok := b.nodes[typ]
if !ok { // already dropped
b.lk.Unlock()
return
}
n.lk.Lock()
if n.nEmitters.Load() > 0 || len(n.sinks) > 0 {
n.lk.Unlock()
b.lk.Unlock()
return // still in use
}
n.lk.Unlock()
delete(b.nodes, typ)
b.lk.Unlock()
}
type wildcardSub struct {
ch chan interface{}
w *wildcardNode
metricsTracer MetricsTracer
name string
}
func (w *wildcardSub) Out() <-chan interface{} {
return w.ch
}
func (w *wildcardSub) Close() error {
w.w.removeSink(w.ch)
if w.metricsTracer != nil {
w.metricsTracer.RemoveSubscriber(reflect.TypeOf(event.WildcardSubscription))
}
return nil
}
func (w *wildcardSub) Name() string {
return w.name
}
type namedSink struct {
name string
ch chan interface{}
}
type sub struct {
ch chan interface{}
nodes []*node
dropper func(reflect.Type)
metricsTracer MetricsTracer
name string
}
func (s *sub) Name() string {
return s.name
}
func (s *sub) Out() <-chan interface{} {
return s.ch
}
func (s *sub) Close() error {
go func() {
// drain the event channel, will return when closed and drained.
// this is necessary to unblock publishes to this channel.
for range s.ch {
}
}()
for _, n := range s.nodes {
n.lk.Lock()
for i := 0; i < len(n.sinks); i++ {
if n.sinks[i].ch == s.ch {
n.sinks[i], n.sinks[len(n.sinks)-1] = n.sinks[len(n.sinks)-1], nil
n.sinks = n.sinks[:len(n.sinks)-1]
if s.metricsTracer != nil {
s.metricsTracer.RemoveSubscriber(n.typ)
}
break
}
}
tryDrop := len(n.sinks) == 0 && n.nEmitters.Load() == 0
n.lk.Unlock()
if tryDrop {
s.dropper(n.typ)
}
}
close(s.ch)
return nil
}
var _ event.Subscription = (*sub)(nil)
// Subscribe creates new subscription. Failing to drain the channel will cause
// publishers to get blocked. CancelFunc is guaranteed to return after last send
// to the channel
func (b *basicBus) Subscribe(evtTypes interface{}, opts ...event.SubscriptionOpt) (_ event.Subscription, err error) {
settings := newSubSettings()
for _, opt := range opts {
if err := opt(&settings); err != nil {
return nil, err
}
}
if evtTypes == event.WildcardSubscription {
out := &wildcardSub{
ch: make(chan interface{}, settings.buffer),
w: b.wildcard,
metricsTracer: b.metricsTracer,
name: settings.name,
}
b.wildcard.addSink(&namedSink{ch: out.ch, name: out.name})
return out, nil
}
types, ok := evtTypes.([]interface{})
if !ok {
types = []interface{}{evtTypes}
}
if len(types) > 1 {
for _, t := range types {
if t == event.WildcardSubscription {
return nil, fmt.Errorf("wildcard subscriptions must be started separately")
}
}
}
out := &sub{
ch: make(chan interface{}, settings.buffer),
nodes: make([]*node, len(types)),
dropper: b.tryDropNode,
metricsTracer: b.metricsTracer,
name: settings.name,
}
for _, etyp := range types {
if reflect.TypeOf(etyp).Kind() != reflect.Ptr {
return nil, errors.New("subscribe called with non-pointer type")
}
}
for i, etyp := range types {
typ := reflect.TypeOf(etyp)
b.withNode(typ.Elem(), func(n *node) {
n.sinks = append(n.sinks, &namedSink{ch: out.ch, name: out.name})
out.nodes[i] = n
if b.metricsTracer != nil {
b.metricsTracer.AddSubscriber(typ.Elem())
}
}, func(n *node) {
if n.keepLast {
l := n.last
if l == nil {
return
}
out.ch <- l
}
})
}
return out, nil
}
// Emitter creates new emitter
//
// eventType accepts typed nil pointers, and uses the type information to
// select output type
//
// Example:
// emit, err := eventbus.Emitter(new(EventT))
// defer emit.Close() // MUST call this after being done with the emitter
//
// emit(EventT{})
func (b *basicBus) Emitter(evtType interface{}, opts ...event.EmitterOpt) (e event.Emitter, err error) {
if evtType == event.WildcardSubscription {
return nil, fmt.Errorf("illegal emitter for wildcard subscription")
}
var settings emitterSettings
for _, opt := range opts {
if err := opt(&settings); err != nil {
return nil, err
}
}
typ := reflect.TypeOf(evtType)
if typ.Kind() != reflect.Ptr {
return nil, errors.New("emitter called with non-pointer type")
}
typ = typ.Elem()
b.withNode(typ, func(n *node) {
n.nEmitters.Add(1)
n.keepLast = n.keepLast || settings.makeStateful
e = &emitter{n: n, typ: typ, dropper: b.tryDropNode, w: b.wildcard, metricsTracer: b.metricsTracer}
}, nil)
return
}
// GetAllEventTypes returns all the event types that this bus has emitters
// or subscribers for.
func (b *basicBus) GetAllEventTypes() []reflect.Type {
b.lk.RLock()
defer b.lk.RUnlock()
types := make([]reflect.Type, 0, len(b.nodes))
for t := range b.nodes {
types = append(types, t)
}
return types
}
// /////////////////////
// NODE
type wildcardNode struct {
sync.RWMutex
nSinks atomic.Int32
sinks []*namedSink
metricsTracer MetricsTracer
}
func (n *wildcardNode) addSink(sink *namedSink) {
n.nSinks.Add(1) // ok to do outside the lock
n.Lock()
n.sinks = append(n.sinks, sink)
n.Unlock()
if n.metricsTracer != nil {
n.metricsTracer.AddSubscriber(reflect.TypeOf(event.WildcardSubscription))
}
}
func (n *wildcardNode) removeSink(ch chan interface{}) {
n.nSinks.Add(-1) // ok to do outside the lock
n.Lock()
for i := 0; i < len(n.sinks); i++ {
if n.sinks[i].ch == ch {
n.sinks[i], n.sinks[len(n.sinks)-1] = n.sinks[len(n.sinks)-1], nil
n.sinks = n.sinks[:len(n.sinks)-1]
break
}
}
n.Unlock()
}
func (n *wildcardNode) emit(evt interface{}) {
if n.nSinks.Load() == 0 {
return
}
n.RLock()
for _, sink := range n.sinks {
// Sending metrics before sending on channel allows us to
// record channel full events before blocking
sendSubscriberMetrics(n.metricsTracer, sink)
sink.ch <- evt
}
n.RUnlock()
}
type node struct {
// Note: make sure to NEVER lock basicBus.lk when this lock is held
lk sync.Mutex
typ reflect.Type
// emitter ref count
nEmitters atomic.Int32
keepLast bool
last interface{}
sinks []*namedSink
metricsTracer MetricsTracer
}
func newNode(typ reflect.Type, metricsTracer MetricsTracer) *node {
return &node{
typ: typ,
metricsTracer: metricsTracer,
}
}
func (n *node) emit(evt interface{}) {
typ := reflect.TypeOf(evt)
if typ != n.typ {
panic(fmt.Sprintf("Emit called with wrong type. expected: %s, got: %s", n.typ, typ))
}
n.lk.Lock()
if n.keepLast {
n.last = evt
}
for _, sink := range n.sinks {
// Sending metrics before sending on channel allows us to
// record channel full events before blocking
sendSubscriberMetrics(n.metricsTracer, sink)
sink.ch <- evt
}
n.lk.Unlock()
}
func sendSubscriberMetrics(metricsTracer MetricsTracer, sink *namedSink) {
if metricsTracer != nil {
metricsTracer.SubscriberQueueLength(sink.name, len(sink.ch)+1)
metricsTracer.SubscriberQueueFull(sink.name, len(sink.ch)+1 >= cap(sink.ch))
metricsTracer.SubscriberEventQueued(sink.name)
}
}

View File

@@ -0,0 +1,164 @@
package eventbus
import (
"reflect"
"strings"
"github.com/libp2p/go-libp2p/p2p/metricshelper"
"github.com/prometheus/client_golang/prometheus"
)
const metricNamespace = "libp2p_eventbus"
var (
eventsEmitted = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "events_emitted_total",
Help: "Events Emitted",
},
[]string{"event"},
)
totalSubscribers = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "subscribers_total",
Help: "Number of subscribers for an event type",
},
[]string{"event"},
)
subscriberQueueLength = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "subscriber_queue_length",
Help: "Subscriber queue length",
},
[]string{"subscriber_name"},
)
subscriberQueueFull = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "subscriber_queue_full",
Help: "Subscriber Queue completely full",
},
[]string{"subscriber_name"},
)
subscriberEventQueued = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "subscriber_event_queued",
Help: "Event Queued for subscriber",
},
[]string{"subscriber_name"},
)
collectors = []prometheus.Collector{
eventsEmitted,
totalSubscribers,
subscriberQueueLength,
subscriberQueueFull,
subscriberEventQueued,
}
)
// MetricsTracer tracks metrics for the eventbus subsystem
type MetricsTracer interface {
// EventEmitted counts the total number of events grouped by event type
EventEmitted(typ reflect.Type)
// AddSubscriber adds a subscriber for the event type
AddSubscriber(typ reflect.Type)
// RemoveSubscriber removes a subscriber for the event type
RemoveSubscriber(typ reflect.Type)
// SubscriberQueueLength is the length of the subscribers channel
SubscriberQueueLength(name string, n int)
// SubscriberQueueFull tracks whether a subscribers channel if full
SubscriberQueueFull(name string, isFull bool)
// SubscriberEventQueued counts the total number of events grouped by subscriber
SubscriberEventQueued(name string)
}
type metricsTracer struct{}
var _ MetricsTracer = &metricsTracer{}
type metricsTracerSetting struct {
reg prometheus.Registerer
}
type MetricsTracerOption func(*metricsTracerSetting)
func WithRegisterer(reg prometheus.Registerer) MetricsTracerOption {
return func(s *metricsTracerSetting) {
if reg != nil {
s.reg = reg
}
}
}
func NewMetricsTracer(opts ...MetricsTracerOption) MetricsTracer {
setting := &metricsTracerSetting{reg: prometheus.DefaultRegisterer}
for _, opt := range opts {
opt(setting)
}
metricshelper.RegisterCollectors(setting.reg, collectors...)
return &metricsTracer{}
}
func (m *metricsTracer) EventEmitted(typ reflect.Type) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, strings.TrimPrefix(typ.String(), "event."))
eventsEmitted.WithLabelValues(*tags...).Inc()
}
func (m *metricsTracer) AddSubscriber(typ reflect.Type) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, strings.TrimPrefix(typ.String(), "event."))
totalSubscribers.WithLabelValues(*tags...).Inc()
}
func (m *metricsTracer) RemoveSubscriber(typ reflect.Type) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, strings.TrimPrefix(typ.String(), "event."))
totalSubscribers.WithLabelValues(*tags...).Dec()
}
func (m *metricsTracer) SubscriberQueueLength(name string, n int) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, name)
subscriberQueueLength.WithLabelValues(*tags...).Set(float64(n))
}
func (m *metricsTracer) SubscriberQueueFull(name string, isFull bool) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, name)
observer := subscriberQueueFull.WithLabelValues(*tags...)
if isFull {
observer.Set(1)
} else {
observer.Set(0)
}
}
func (m *metricsTracer) SubscriberEventQueued(name string) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, name)
subscriberEventQueued.WithLabelValues(*tags...).Inc()
}

View File

@@ -0,0 +1,79 @@
package eventbus
import (
"fmt"
"runtime"
"strings"
"sync/atomic"
)
type subSettings struct {
buffer int
name string
}
var subCnt atomic.Int64
var subSettingsDefault = subSettings{
buffer: 16,
}
// newSubSettings returns the settings for a new subscriber
// The default naming strategy is sub-<fileName>-L<lineNum>
func newSubSettings() subSettings {
settings := subSettingsDefault
_, file, line, ok := runtime.Caller(2) // skip=1 is eventbus.Subscriber
if ok {
file = strings.TrimPrefix(file, "github.com/")
// remove the version number from the path, for example
// go-libp2p-package@v0.x.y-some-hash-123/file.go will be shortened go go-libp2p-package/file.go
if idx1 := strings.Index(file, "@"); idx1 != -1 {
if idx2 := strings.Index(file[idx1:], "/"); idx2 != -1 {
file = file[:idx1] + file[idx1+idx2:]
}
}
settings.name = fmt.Sprintf("%s-L%d", file, line)
} else {
settings.name = fmt.Sprintf("subscriber-%d", subCnt.Add(1))
}
return settings
}
func BufSize(n int) func(interface{}) error {
return func(s interface{}) error {
s.(*subSettings).buffer = n
return nil
}
}
func Name(name string) func(interface{}) error {
return func(s interface{}) error {
s.(*subSettings).name = name
return nil
}
}
type emitterSettings struct {
makeStateful bool
}
// Stateful is an Emitter option which makes the eventbus channel
// 'remember' last event sent, and when a new subscriber joins the
// bus, the remembered event is immediately sent to the subscription
// channel.
//
// This allows to provide state tracking for dynamic systems, and/or
// allows new subscribers to verify that there are Emitters on the channel
func Stateful(s interface{}) error {
s.(*emitterSettings).makeStateful = true
return nil
}
type Option func(*basicBus)
func WithMetricsTracer(metricsTracer MetricsTracer) Option {
return func(bus *basicBus) {
bus.metricsTracer = metricsTracer
bus.wildcard.metricsTracer = metricsTracer
}
}

View File

@@ -0,0 +1,58 @@
package peerstore
import (
"sync"
"time"
"github.com/libp2p/go-libp2p/core/peer"
)
// LatencyEWMASmoothing governs the decay of the EWMA (the speed
// at which it changes). This must be a normalized (0-1) value.
// 1 is 100% change, 0 is no change.
var LatencyEWMASmoothing = 0.1
type metrics struct {
mutex sync.RWMutex
latmap map[peer.ID]time.Duration
}
func NewMetrics() *metrics {
return &metrics{
latmap: make(map[peer.ID]time.Duration),
}
}
// RecordLatency records a new latency measurement
func (m *metrics) RecordLatency(p peer.ID, next time.Duration) {
nextf := float64(next)
s := LatencyEWMASmoothing
if s > 1 || s < 0 {
s = 0.1 // ignore the knob. it's broken. look, it jiggles.
}
m.mutex.Lock()
ewma, found := m.latmap[p]
ewmaf := float64(ewma)
if !found {
m.latmap[p] = next // when no data, just take it as the mean.
} else {
nextf = ((1.0 - s) * ewmaf) + (s * nextf)
m.latmap[p] = time.Duration(nextf)
}
m.mutex.Unlock()
}
// LatencyEWMA returns an exponentially-weighted moving avg.
// of all measurements of a peer's latency.
func (m *metrics) LatencyEWMA(p peer.ID) time.Duration {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.latmap[p]
}
func (m *metrics) RemovePeer(p peer.ID) {
m.mutex.Lock()
delete(m.latmap, p)
m.mutex.Unlock()
}

View File

@@ -0,0 +1,22 @@
package peerstore
import (
"github.com/libp2p/go-libp2p/core/peer"
pstore "github.com/libp2p/go-libp2p/core/peerstore"
)
func PeerInfos(ps pstore.Peerstore, peers peer.IDSlice) []peer.AddrInfo {
pi := make([]peer.AddrInfo, len(peers))
for i, p := range peers {
pi[i] = ps.PeerInfo(p)
}
return pi
}
func PeerInfoIDs(pis []peer.AddrInfo) peer.IDSlice {
ps := make(peer.IDSlice, len(pis))
for i, pi := range pis {
ps[i] = pi.ID
}
return ps
}

View File

@@ -0,0 +1,530 @@
package pstoremem
import (
"context"
"fmt"
"sort"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/peer"
pstore "github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/record"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
)
var log = logging.Logger("peerstore")
type expiringAddr struct {
Addr ma.Multiaddr
TTL time.Duration
Expires time.Time
}
func (e *expiringAddr) ExpiredBy(t time.Time) bool {
return !t.Before(e.Expires)
}
type peerRecordState struct {
Envelope *record.Envelope
Seq uint64
}
type addrSegments [256]*addrSegment
type addrSegment struct {
sync.RWMutex
// Use pointers to save memory. Maps always leave some fraction of their
// space unused. storing the *values* directly in the map will
// drastically increase the space waste. In our case, by 6x.
addrs map[peer.ID]map[string]*expiringAddr
signedPeerRecords map[peer.ID]*peerRecordState
}
func (segments *addrSegments) get(p peer.ID) *addrSegment {
if len(p) == 0 { // it's not terribly useful to use an empty peer ID, but at least we should not panic
return segments[0]
}
return segments[uint8(p[len(p)-1])]
}
type clock interface {
Now() time.Time
}
type realclock struct{}
func (rc realclock) Now() time.Time {
return time.Now()
}
// memoryAddrBook manages addresses.
type memoryAddrBook struct {
segments addrSegments
refCount sync.WaitGroup
cancel func()
subManager *AddrSubManager
clock clock
}
var _ pstore.AddrBook = (*memoryAddrBook)(nil)
var _ pstore.CertifiedAddrBook = (*memoryAddrBook)(nil)
func NewAddrBook() *memoryAddrBook {
ctx, cancel := context.WithCancel(context.Background())
ab := &memoryAddrBook{
segments: func() (ret addrSegments) {
for i := range ret {
ret[i] = &addrSegment{
addrs: make(map[peer.ID]map[string]*expiringAddr),
signedPeerRecords: make(map[peer.ID]*peerRecordState)}
}
return ret
}(),
subManager: NewAddrSubManager(),
cancel: cancel,
clock: realclock{},
}
ab.refCount.Add(1)
go ab.background(ctx)
return ab
}
type AddrBookOption func(book *memoryAddrBook) error
func WithClock(clock clock) AddrBookOption {
return func(book *memoryAddrBook) error {
book.clock = clock
return nil
}
}
// background periodically schedules a gc
func (mab *memoryAddrBook) background(ctx context.Context) {
defer mab.refCount.Done()
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ticker.C:
mab.gc()
case <-ctx.Done():
return
}
}
}
func (mab *memoryAddrBook) Close() error {
mab.cancel()
mab.refCount.Wait()
return nil
}
// gc garbage collects the in-memory address book.
func (mab *memoryAddrBook) gc() {
now := mab.clock.Now()
for _, s := range mab.segments {
s.Lock()
for p, amap := range s.addrs {
for k, addr := range amap {
if addr.ExpiredBy(now) {
delete(amap, k)
}
}
if len(amap) == 0 {
delete(s.addrs, p)
delete(s.signedPeerRecords, p)
}
}
s.Unlock()
}
}
func (mab *memoryAddrBook) PeersWithAddrs() peer.IDSlice {
// deduplicate, since the same peer could have both signed & unsigned addrs
set := make(map[peer.ID]struct{})
for _, s := range mab.segments {
s.RLock()
for pid, amap := range s.addrs {
if len(amap) > 0 {
set[pid] = struct{}{}
}
}
s.RUnlock()
}
peers := make(peer.IDSlice, 0, len(set))
for pid := range set {
peers = append(peers, pid)
}
return peers
}
// AddAddr calls AddAddrs(p, []ma.Multiaddr{addr}, ttl)
func (mab *memoryAddrBook) AddAddr(p peer.ID, addr ma.Multiaddr, ttl time.Duration) {
mab.AddAddrs(p, []ma.Multiaddr{addr}, ttl)
}
// AddAddrs gives memoryAddrBook addresses to use, with a given ttl
// (time-to-live), after which the address is no longer valid.
// This function never reduces the TTL or expiration of an address.
func (mab *memoryAddrBook) AddAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration) {
// if we have a valid peer record, ignore unsigned addrs
// peerRec := mab.GetPeerRecord(p)
// if peerRec != nil {
// return
// }
mab.addAddrs(p, addrs, ttl)
}
// ConsumePeerRecord adds addresses from a signed peer.PeerRecord (contained in
// a record.Envelope), which will expire after the given TTL.
// See https://godoc.org/github.com/libp2p/go-libp2p/core/peerstore#CertifiedAddrBook for more details.
func (mab *memoryAddrBook) ConsumePeerRecord(recordEnvelope *record.Envelope, ttl time.Duration) (bool, error) {
r, err := recordEnvelope.Record()
if err != nil {
return false, err
}
rec, ok := r.(*peer.PeerRecord)
if !ok {
return false, fmt.Errorf("unable to process envelope: not a PeerRecord")
}
if !rec.PeerID.MatchesPublicKey(recordEnvelope.PublicKey) {
return false, fmt.Errorf("signing key does not match PeerID in PeerRecord")
}
// ensure seq is greater than, or equal to, the last received
s := mab.segments.get(rec.PeerID)
s.Lock()
defer s.Unlock()
lastState, found := s.signedPeerRecords[rec.PeerID]
if found && lastState.Seq > rec.Seq {
return false, nil
}
s.signedPeerRecords[rec.PeerID] = &peerRecordState{
Envelope: recordEnvelope,
Seq: rec.Seq,
}
mab.addAddrsUnlocked(s, rec.PeerID, rec.Addrs, ttl, true)
return true, nil
}
func (mab *memoryAddrBook) addAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration) {
s := mab.segments.get(p)
s.Lock()
defer s.Unlock()
mab.addAddrsUnlocked(s, p, addrs, ttl, false)
}
func (mab *memoryAddrBook) addAddrsUnlocked(s *addrSegment, p peer.ID, addrs []ma.Multiaddr, ttl time.Duration, signed bool) {
// if ttl is zero, exit. nothing to do.
if ttl <= 0 {
return
}
amap, ok := s.addrs[p]
if !ok {
amap = make(map[string]*expiringAddr)
s.addrs[p] = amap
}
exp := mab.clock.Now().Add(ttl)
for _, addr := range addrs {
// Remove suffix of /p2p/peer-id from address
addr, addrPid := peer.SplitAddr(addr)
if addr == nil {
log.Warnw("Was passed nil multiaddr", "peer", p)
continue
}
if addrPid != "" && addrPid != p {
log.Warnf("Was passed p2p address with a different peerId. found: %s, expected: %s", addrPid, p)
continue
}
// find the highest TTL and Expiry time between
// existing records and function args
a, found := amap[string(addr.Bytes())] // won't allocate.
if !found {
// not found, announce it.
entry := &expiringAddr{Addr: addr, Expires: exp, TTL: ttl}
amap[string(addr.Bytes())] = entry
mab.subManager.BroadcastAddr(p, addr)
} else {
// update ttl & exp to whichever is greater between new and existing entry
if ttl > a.TTL {
a.TTL = ttl
}
if exp.After(a.Expires) {
a.Expires = exp
}
}
}
}
// SetAddr calls mgr.SetAddrs(p, addr, ttl)
func (mab *memoryAddrBook) SetAddr(p peer.ID, addr ma.Multiaddr, ttl time.Duration) {
mab.SetAddrs(p, []ma.Multiaddr{addr}, ttl)
}
// SetAddrs sets the ttl on addresses. This clears any TTL there previously.
// This is used when we receive the best estimate of the validity of an address.
func (mab *memoryAddrBook) SetAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration) {
s := mab.segments.get(p)
s.Lock()
defer s.Unlock()
amap, ok := s.addrs[p]
if !ok {
amap = make(map[string]*expiringAddr)
s.addrs[p] = amap
}
exp := mab.clock.Now().Add(ttl)
for _, addr := range addrs {
addr, addrPid := peer.SplitAddr(addr)
if addr == nil {
log.Warnw("was passed nil multiaddr", "peer", p)
continue
}
if addrPid != "" && addrPid != p {
log.Warnf("was passed p2p address with a different peerId, found: %s wanted: %s", addrPid, p)
continue
}
aBytes := addr.Bytes()
key := string(aBytes)
// re-set all of them for new ttl.
if ttl > 0 {
amap[key] = &expiringAddr{Addr: addr, Expires: exp, TTL: ttl}
mab.subManager.BroadcastAddr(p, addr)
} else {
delete(amap, key)
}
}
}
// UpdateAddrs updates the addresses associated with the given peer that have
// the given oldTTL to have the given newTTL.
func (mab *memoryAddrBook) UpdateAddrs(p peer.ID, oldTTL time.Duration, newTTL time.Duration) {
s := mab.segments.get(p)
s.Lock()
defer s.Unlock()
exp := mab.clock.Now().Add(newTTL)
amap, found := s.addrs[p]
if !found {
return
}
for k, a := range amap {
if oldTTL == a.TTL {
if newTTL == 0 {
delete(amap, k)
} else {
a.TTL = newTTL
a.Expires = exp
amap[k] = a
}
}
}
}
// Addrs returns all known (and valid) addresses for a given peer
func (mab *memoryAddrBook) Addrs(p peer.ID) []ma.Multiaddr {
s := mab.segments.get(p)
s.RLock()
defer s.RUnlock()
return validAddrs(mab.clock.Now(), s.addrs[p])
}
func validAddrs(now time.Time, amap map[string]*expiringAddr) []ma.Multiaddr {
good := make([]ma.Multiaddr, 0, len(amap))
if amap == nil {
return good
}
for _, m := range amap {
if !m.ExpiredBy(now) {
good = append(good, m.Addr)
}
}
return good
}
// GetPeerRecord returns a Envelope containing a PeerRecord for the
// given peer id, if one exists.
// Returns nil if no signed PeerRecord exists for the peer.
func (mab *memoryAddrBook) GetPeerRecord(p peer.ID) *record.Envelope {
s := mab.segments.get(p)
s.RLock()
defer s.RUnlock()
// although the signed record gets garbage collected when all addrs inside it are expired,
// we may be in between the expiration time and the GC interval
// so, we check to see if we have any valid signed addrs before returning the record
if len(validAddrs(mab.clock.Now(), s.addrs[p])) == 0 {
return nil
}
state := s.signedPeerRecords[p]
if state == nil {
return nil
}
return state.Envelope
}
// ClearAddrs removes all previously stored addresses
func (mab *memoryAddrBook) ClearAddrs(p peer.ID) {
s := mab.segments.get(p)
s.Lock()
defer s.Unlock()
delete(s.addrs, p)
delete(s.signedPeerRecords, p)
}
// AddrStream returns a channel on which all new addresses discovered for a
// given peer ID will be published.
func (mab *memoryAddrBook) AddrStream(ctx context.Context, p peer.ID) <-chan ma.Multiaddr {
s := mab.segments.get(p)
s.RLock()
defer s.RUnlock()
baseaddrslice := s.addrs[p]
initial := make([]ma.Multiaddr, 0, len(baseaddrslice))
for _, a := range baseaddrslice {
initial = append(initial, a.Addr)
}
return mab.subManager.AddrStream(ctx, p, initial)
}
type addrSub struct {
pubch chan ma.Multiaddr
ctx context.Context
}
func (s *addrSub) pubAddr(a ma.Multiaddr) {
select {
case s.pubch <- a:
case <-s.ctx.Done():
}
}
// An abstracted, pub-sub manager for address streams. Extracted from
// memoryAddrBook in order to support additional implementations.
type AddrSubManager struct {
mu sync.RWMutex
subs map[peer.ID][]*addrSub
}
// NewAddrSubManager initializes an AddrSubManager.
func NewAddrSubManager() *AddrSubManager {
return &AddrSubManager{
subs: make(map[peer.ID][]*addrSub),
}
}
// Used internally by the address stream coroutine to remove a subscription
// from the manager.
func (mgr *AddrSubManager) removeSub(p peer.ID, s *addrSub) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
subs := mgr.subs[p]
if len(subs) == 1 {
if subs[0] != s {
return
}
delete(mgr.subs, p)
return
}
for i, v := range subs {
if v == s {
subs[i] = subs[len(subs)-1]
subs[len(subs)-1] = nil
mgr.subs[p] = subs[:len(subs)-1]
return
}
}
}
// BroadcastAddr broadcasts a new address to all subscribed streams.
func (mgr *AddrSubManager) BroadcastAddr(p peer.ID, addr ma.Multiaddr) {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
if subs, ok := mgr.subs[p]; ok {
for _, sub := range subs {
sub.pubAddr(addr)
}
}
}
// AddrStream creates a new subscription for a given peer ID, pre-populating the
// channel with any addresses we might already have on file.
func (mgr *AddrSubManager) AddrStream(ctx context.Context, p peer.ID, initial []ma.Multiaddr) <-chan ma.Multiaddr {
sub := &addrSub{pubch: make(chan ma.Multiaddr), ctx: ctx}
out := make(chan ma.Multiaddr)
mgr.mu.Lock()
mgr.subs[p] = append(mgr.subs[p], sub)
mgr.mu.Unlock()
sort.Sort(addrList(initial))
go func(buffer []ma.Multiaddr) {
defer close(out)
sent := make(map[string]struct{}, len(buffer))
for _, a := range buffer {
sent[string(a.Bytes())] = struct{}{}
}
var outch chan ma.Multiaddr
var next ma.Multiaddr
if len(buffer) > 0 {
next = buffer[0]
buffer = buffer[1:]
outch = out
}
for {
select {
case outch <- next:
if len(buffer) > 0 {
next = buffer[0]
buffer = buffer[1:]
} else {
outch = nil
next = nil
}
case naddr := <-sub.pubch:
if _, ok := sent[string(naddr.Bytes())]; ok {
continue
}
sent[string(naddr.Bytes())] = struct{}{}
if next == nil {
next = naddr
outch = out
} else {
buffer = append(buffer, naddr)
}
case <-ctx.Done():
mgr.removeSub(p, sub)
return
}
}
}(initial)
return out
}

View File

@@ -0,0 +1,97 @@
package pstoremem
import (
"errors"
"sync"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
pstore "github.com/libp2p/go-libp2p/core/peerstore"
)
type memoryKeyBook struct {
sync.RWMutex // same lock. wont happen a ton.
pks map[peer.ID]ic.PubKey
sks map[peer.ID]ic.PrivKey
}
var _ pstore.KeyBook = (*memoryKeyBook)(nil)
func NewKeyBook() *memoryKeyBook {
return &memoryKeyBook{
pks: map[peer.ID]ic.PubKey{},
sks: map[peer.ID]ic.PrivKey{},
}
}
func (mkb *memoryKeyBook) PeersWithKeys() peer.IDSlice {
mkb.RLock()
ps := make(peer.IDSlice, 0, len(mkb.pks)+len(mkb.sks))
for p := range mkb.pks {
ps = append(ps, p)
}
for p := range mkb.sks {
if _, found := mkb.pks[p]; !found {
ps = append(ps, p)
}
}
mkb.RUnlock()
return ps
}
func (mkb *memoryKeyBook) PubKey(p peer.ID) ic.PubKey {
mkb.RLock()
pk := mkb.pks[p]
mkb.RUnlock()
if pk != nil {
return pk
}
pk, err := p.ExtractPublicKey()
if err == nil {
mkb.Lock()
mkb.pks[p] = pk
mkb.Unlock()
}
return pk
}
func (mkb *memoryKeyBook) AddPubKey(p peer.ID, pk ic.PubKey) error {
// check it's correct first
if !p.MatchesPublicKey(pk) {
return errors.New("ID does not match PublicKey")
}
mkb.Lock()
mkb.pks[p] = pk
mkb.Unlock()
return nil
}
func (mkb *memoryKeyBook) PrivKey(p peer.ID) ic.PrivKey {
mkb.RLock()
defer mkb.RUnlock()
return mkb.sks[p]
}
func (mkb *memoryKeyBook) AddPrivKey(p peer.ID, sk ic.PrivKey) error {
if sk == nil {
return errors.New("sk is nil (PrivKey)")
}
// check it's correct first
if !p.MatchesPrivateKey(sk) {
return errors.New("ID does not match PrivateKey")
}
mkb.Lock()
mkb.sks[p] = sk
mkb.Unlock()
return nil
}
func (mkb *memoryKeyBook) RemovePeer(p peer.ID) {
mkb.Lock()
delete(mkb.sks, p)
delete(mkb.pks, p)
mkb.Unlock()
}

View File

@@ -0,0 +1,54 @@
package pstoremem
import (
"sync"
"github.com/libp2p/go-libp2p/core/peer"
pstore "github.com/libp2p/go-libp2p/core/peerstore"
)
type memoryPeerMetadata struct {
// store other data, like versions
ds map[peer.ID]map[string]interface{}
dslock sync.RWMutex
}
var _ pstore.PeerMetadata = (*memoryPeerMetadata)(nil)
func NewPeerMetadata() *memoryPeerMetadata {
return &memoryPeerMetadata{
ds: make(map[peer.ID]map[string]interface{}),
}
}
func (ps *memoryPeerMetadata) Put(p peer.ID, key string, val interface{}) error {
ps.dslock.Lock()
defer ps.dslock.Unlock()
m, ok := ps.ds[p]
if !ok {
m = make(map[string]interface{})
ps.ds[p] = m
}
m[key] = val
return nil
}
func (ps *memoryPeerMetadata) Get(p peer.ID, key string) (interface{}, error) {
ps.dslock.RLock()
defer ps.dslock.RUnlock()
m, ok := ps.ds[p]
if !ok {
return nil, pstore.ErrNotFound
}
val, ok := m[key]
if !ok {
return nil, pstore.ErrNotFound
}
return val, nil
}
func (ps *memoryPeerMetadata) RemovePeer(p peer.ID) {
ps.dslock.Lock()
delete(ps.ds, p)
ps.dslock.Unlock()
}

View File

@@ -0,0 +1,114 @@
package pstoremem
import (
"fmt"
"io"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
pstore "github.com/libp2p/go-libp2p/p2p/host/peerstore"
)
type pstoremem struct {
peerstore.Metrics
*memoryKeyBook
*memoryAddrBook
*memoryProtoBook
*memoryPeerMetadata
}
var _ peerstore.Peerstore = &pstoremem{}
type Option interface{}
// NewPeerstore creates an in-memory thread-safe collection of peers.
// It's the caller's responsibility to call RemovePeer to ensure
// that memory consumption of the peerstore doesn't grow unboundedly.
func NewPeerstore(opts ...Option) (ps *pstoremem, err error) {
ab := NewAddrBook()
defer func() {
if err != nil {
ab.Close()
}
}()
var protoBookOpts []ProtoBookOption
for _, opt := range opts {
switch o := opt.(type) {
case ProtoBookOption:
protoBookOpts = append(protoBookOpts, o)
case AddrBookOption:
o(ab)
default:
return nil, fmt.Errorf("unexpected peer store option: %v", o)
}
}
pb, err := NewProtoBook(protoBookOpts...)
if err != nil {
return nil, err
}
return &pstoremem{
Metrics: pstore.NewMetrics(),
memoryKeyBook: NewKeyBook(),
memoryAddrBook: ab,
memoryProtoBook: pb,
memoryPeerMetadata: NewPeerMetadata(),
}, nil
}
func (ps *pstoremem) Close() (err error) {
var errs []error
weakClose := func(name string, c interface{}) {
if cl, ok := c.(io.Closer); ok {
if err = cl.Close(); err != nil {
errs = append(errs, fmt.Errorf("%s error: %s", name, err))
}
}
}
weakClose("keybook", ps.memoryKeyBook)
weakClose("addressbook", ps.memoryAddrBook)
weakClose("protobook", ps.memoryProtoBook)
weakClose("peermetadata", ps.memoryPeerMetadata)
if len(errs) > 0 {
return fmt.Errorf("failed while closing peerstore; err(s): %q", errs)
}
return nil
}
func (ps *pstoremem) Peers() peer.IDSlice {
set := map[peer.ID]struct{}{}
for _, p := range ps.PeersWithKeys() {
set[p] = struct{}{}
}
for _, p := range ps.PeersWithAddrs() {
set[p] = struct{}{}
}
pps := make(peer.IDSlice, 0, len(set))
for p := range set {
pps = append(pps, p)
}
return pps
}
func (ps *pstoremem) PeerInfo(p peer.ID) peer.AddrInfo {
return peer.AddrInfo{
ID: p,
Addrs: ps.memoryAddrBook.Addrs(p),
}
}
// RemovePeer removes entries associated with a peer from:
// * the KeyBook
// * the ProtoBook
// * the PeerMetadata
// * the Metrics
// It DOES NOT remove the peer from the AddrBook.
func (ps *pstoremem) RemovePeer(p peer.ID) {
ps.memoryKeyBook.RemovePeer(p)
ps.memoryProtoBook.RemovePeer(p)
ps.memoryPeerMetadata.RemovePeer(p)
ps.Metrics.RemovePeer(p)
}

View File

@@ -0,0 +1,192 @@
package pstoremem
import (
"errors"
"sync"
"github.com/libp2p/go-libp2p/core/peer"
pstore "github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
)
type protoSegment struct {
sync.RWMutex
protocols map[peer.ID]map[protocol.ID]struct{}
}
type protoSegments [256]*protoSegment
func (s *protoSegments) get(p peer.ID) *protoSegment {
return s[byte(p[len(p)-1])]
}
var errTooManyProtocols = errors.New("too many protocols")
type memoryProtoBook struct {
segments protoSegments
maxProtos int
lk sync.RWMutex
interned map[protocol.ID]protocol.ID
}
var _ pstore.ProtoBook = (*memoryProtoBook)(nil)
type ProtoBookOption func(book *memoryProtoBook) error
func WithMaxProtocols(num int) ProtoBookOption {
return func(pb *memoryProtoBook) error {
pb.maxProtos = num
return nil
}
}
func NewProtoBook(opts ...ProtoBookOption) (*memoryProtoBook, error) {
pb := &memoryProtoBook{
interned: make(map[protocol.ID]protocol.ID, 256),
segments: func() (ret protoSegments) {
for i := range ret {
ret[i] = &protoSegment{
protocols: make(map[peer.ID]map[protocol.ID]struct{}),
}
}
return ret
}(),
maxProtos: 1024,
}
for _, opt := range opts {
if err := opt(pb); err != nil {
return nil, err
}
}
return pb, nil
}
func (pb *memoryProtoBook) internProtocol(proto protocol.ID) protocol.ID {
// check if it is interned with the read lock
pb.lk.RLock()
interned, ok := pb.interned[proto]
pb.lk.RUnlock()
if ok {
return interned
}
// intern with the write lock
pb.lk.Lock()
defer pb.lk.Unlock()
// check again in case it got interned in between locks
interned, ok = pb.interned[proto]
if ok {
return interned
}
pb.interned[proto] = proto
return proto
}
func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...protocol.ID) error {
if len(protos) > pb.maxProtos {
return errTooManyProtocols
}
newprotos := make(map[protocol.ID]struct{}, len(protos))
for _, proto := range protos {
newprotos[pb.internProtocol(proto)] = struct{}{}
}
s := pb.segments.get(p)
s.Lock()
s.protocols[p] = newprotos
s.Unlock()
return nil
}
func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...protocol.ID) error {
s := pb.segments.get(p)
s.Lock()
defer s.Unlock()
protomap, ok := s.protocols[p]
if !ok {
protomap = make(map[protocol.ID]struct{})
s.protocols[p] = protomap
}
if len(protomap)+len(protos) > pb.maxProtos {
return errTooManyProtocols
}
for _, proto := range protos {
protomap[pb.internProtocol(proto)] = struct{}{}
}
return nil
}
func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]protocol.ID, error) {
s := pb.segments.get(p)
s.RLock()
defer s.RUnlock()
out := make([]protocol.ID, 0, len(s.protocols[p]))
for k := range s.protocols[p] {
out = append(out, k)
}
return out, nil
}
func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...protocol.ID) error {
s := pb.segments.get(p)
s.Lock()
defer s.Unlock()
protomap, ok := s.protocols[p]
if !ok {
// nothing to remove.
return nil
}
for _, proto := range protos {
delete(protomap, pb.internProtocol(proto))
}
return nil
}
func (pb *memoryProtoBook) SupportsProtocols(p peer.ID, protos ...protocol.ID) ([]protocol.ID, error) {
s := pb.segments.get(p)
s.RLock()
defer s.RUnlock()
out := make([]protocol.ID, 0, len(protos))
for _, proto := range protos {
if _, ok := s.protocols[p][proto]; ok {
out = append(out, proto)
}
}
return out, nil
}
func (pb *memoryProtoBook) FirstSupportedProtocol(p peer.ID, protos ...protocol.ID) (protocol.ID, error) {
s := pb.segments.get(p)
s.RLock()
defer s.RUnlock()
for _, proto := range protos {
if _, ok := s.protocols[p][proto]; ok {
return proto, nil
}
}
return "", nil
}
func (pb *memoryProtoBook) RemovePeer(p peer.ID) {
s := pb.segments.get(p)
s.Lock()
delete(s.protocols, p)
s.Unlock()
}

View File

@@ -0,0 +1,50 @@
package pstoremem
import (
"bytes"
ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
manet "github.com/multiformats/go-multiaddr/net"
)
func isFDCostlyTransport(a ma.Multiaddr) bool {
return mafmt.TCP.Matches(a)
}
type addrList []ma.Multiaddr
func (al addrList) Len() int { return len(al) }
func (al addrList) Swap(i, j int) { al[i], al[j] = al[j], al[i] }
func (al addrList) Less(i, j int) bool {
a := al[i]
b := al[j]
// dial localhost addresses next, they should fail immediately
lba := manet.IsIPLoopback(a)
lbb := manet.IsIPLoopback(b)
if lba && !lbb {
return true
}
// dial utp and similar 'non-fd-consuming' addresses first
fda := isFDCostlyTransport(a)
fdb := isFDCostlyTransport(b)
if !fda {
return fdb
}
// if 'b' doesnt take a file descriptor
if !fdb {
return false
}
// if 'b' is loopback and both take file descriptors
if lbb {
return false
}
// for the rest, just sort by bytes
return bytes.Compare(a.Bytes(), b.Bytes()) > 0
}

View File

@@ -0,0 +1,132 @@
package pstoremanager
import (
"context"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
logging "github.com/ipfs/go-log/v2"
)
var log = logging.Logger("pstoremanager")
type Option func(*PeerstoreManager) error
// WithGracePeriod sets the grace period.
// If a peer doesn't reconnect during the grace period, its data is removed.
// Default: 1 minute.
func WithGracePeriod(p time.Duration) Option {
return func(m *PeerstoreManager) error {
m.gracePeriod = p
return nil
}
}
// WithCleanupInterval set the clean up interval.
// During a clean up run peers that disconnected before the grace period are removed.
// If unset, the interval is set to half the grace period.
func WithCleanupInterval(t time.Duration) Option {
return func(m *PeerstoreManager) error {
m.cleanupInterval = t
return nil
}
}
type PeerstoreManager struct {
pstore peerstore.Peerstore
eventBus event.Bus
cancel context.CancelFunc
refCount sync.WaitGroup
gracePeriod time.Duration
cleanupInterval time.Duration
}
func NewPeerstoreManager(pstore peerstore.Peerstore, eventBus event.Bus, opts ...Option) (*PeerstoreManager, error) {
m := &PeerstoreManager{
pstore: pstore,
gracePeriod: time.Minute,
eventBus: eventBus,
}
for _, opt := range opts {
if err := opt(m); err != nil {
return nil, err
}
}
if m.cleanupInterval == 0 {
m.cleanupInterval = m.gracePeriod / 2
}
return m, nil
}
func (m *PeerstoreManager) Start() {
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
sub, err := m.eventBus.Subscribe(&event.EvtPeerConnectednessChanged{}, eventbus.Name("pstoremanager"))
if err != nil {
log.Warnf("subscription failed. Peerstore manager not activated. Error: %s", err)
return
}
m.refCount.Add(1)
go m.background(ctx, sub)
}
func (m *PeerstoreManager) background(ctx context.Context, sub event.Subscription) {
defer m.refCount.Done()
defer sub.Close()
disconnected := make(map[peer.ID]time.Time)
ticker := time.NewTicker(m.cleanupInterval)
defer ticker.Stop()
defer func() {
for p := range disconnected {
m.pstore.RemovePeer(p)
}
}()
for {
select {
case e, ok := <-sub.Out():
if !ok {
return
}
ev := e.(event.EvtPeerConnectednessChanged)
p := ev.Peer
switch ev.Connectedness {
case network.NotConnected:
if _, ok := disconnected[p]; !ok {
disconnected[p] = time.Now()
}
case network.Connected:
// If we reconnect to the peer before we've cleared the information, keep it.
delete(disconnected, p)
}
case <-ticker.C:
now := time.Now()
for p, disconnectTime := range disconnected {
if disconnectTime.Add(m.gracePeriod).Before(now) {
m.pstore.RemovePeer(p)
delete(disconnected, p)
}
}
case <-ctx.Done():
return
}
}
}
func (m *PeerstoreManager) Close() error {
if m.cancel != nil {
m.cancel()
}
m.refCount.Wait()
return nil
}

View File

@@ -0,0 +1,96 @@
package relaysvc
import (
"context"
"sync"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
)
type RelayManager struct {
host host.Host
mutex sync.Mutex
relay *relayv2.Relay
opts []relayv2.Option
refCount sync.WaitGroup
ctxCancel context.CancelFunc
}
func NewRelayManager(host host.Host, opts ...relayv2.Option) *RelayManager {
ctx, cancel := context.WithCancel(context.Background())
m := &RelayManager{
host: host,
opts: opts,
ctxCancel: cancel,
}
m.refCount.Add(1)
go m.background(ctx)
return m
}
func (m *RelayManager) background(ctx context.Context) {
defer m.refCount.Done()
defer func() {
m.mutex.Lock()
if m.relay != nil {
m.relay.Close()
}
m.mutex.Unlock()
}()
subReachability, _ := m.host.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("relaysvc"))
defer subReachability.Close()
for {
select {
case <-ctx.Done():
return
case ev, ok := <-subReachability.Out():
if !ok {
return
}
if err := m.reachabilityChanged(ev.(event.EvtLocalReachabilityChanged).Reachability); err != nil {
return
}
}
}
}
func (m *RelayManager) reachabilityChanged(r network.Reachability) error {
switch r {
case network.ReachabilityPublic:
m.mutex.Lock()
defer m.mutex.Unlock()
// This could happen if two consecutive EvtLocalReachabilityChanged report the same reachability.
// This shouldn't happen, but it's safer to double-check.
if m.relay != nil {
return nil
}
relay, err := relayv2.New(m.host, m.opts...)
if err != nil {
return err
}
m.relay = relay
default:
m.mutex.Lock()
defer m.mutex.Unlock()
if m.relay != nil {
err := m.relay.Close()
m.relay = nil
return err
}
}
return nil
}
func (m *RelayManager) Close() error {
m.ctxCancel()
m.refCount.Wait()
return nil
}

View File

@@ -0,0 +1,626 @@
# The libp2p Network Resource Manager
This package contains the canonical implementation of the libp2p
Network Resource Manager interface.
The implementation is based on the concept of Resource Management
Scopes, whereby resource usage is constrained by a DAG of scopes,
accounting for multiple levels of resource constraints.
The Resource Manager doesn't prioritize resource requests at all, it simply
checks if the resource being requested is currently below the defined limits and
returns an error if the limit is reached. It has no notion of honest vs bad peers.
The Resource Manager does have a special notion of [allowlisted](#allowlisting-multiaddrs-to-mitigate-eclipse-attacks) multiaddrs that
have their own limits if the normal system limits are reached.
## Usage
The Resource Manager is intended to be used with go-libp2p. go-libp2p sets up a
resource manager with the default autoscaled limits if none is provided, but if
you want to configure things or if you want to enable metrics you'll use the
resource manager like so:
```go
// Start with the default scaling limits.
scalingLimits := rcmgr.DefaultLimits
// Add limits around included libp2p protocols
libp2p.SetDefaultServiceLimits(&scalingLimits)
// Turn the scaling limits into a concrete set of limits using `.AutoScale`. This
// scales the limits proportional to your system memory.
scaledDefaultLimits := scalingLimits.AutoScale()
// Tweak certain settings
cfg := rcmgr.PartialLimitConfig{
System: rcmgr.ResourceLimits{
// Allow unlimited outbound streams
StreamsOutbound: rcmgr.Unlimited,
},
// Everything else is default. The exact values will come from `scaledDefaultLimits` above.
}
// Create our limits by using our cfg and replacing the default values with values from `scaledDefaultLimits`
limits := cfg.Build(scaledDefaultLimits)
// The resource manager expects a limiter, se we create one from our limits.
limiter := rcmgr.NewFixedLimiter(limits)
// Metrics are enabled by default. If you want to disable metrics, use the
// WithMetricsDisabled option
// Initialize the resource manager
rm, err := rcmgr.NewResourceManager(limiter, rcmgr.WithMetricsDisabled())
if err != nil {
panic(err)
}
// Create a libp2p host
host, err := libp2p.New(libp2p.ResourceManager(rm))
```
### Saving the limits config
The easiest way to save the defined limits is to serialize the `PartialLimitConfig`
type as JSON.
```go
noisyNeighbor, _ := peer.Decode("QmVvtzcZgCkMnSFf2dnrBPXrWuNFWNM9J3MpZQCvWPuVZf")
cfg := rcmgr.PartialLimitConfig{
System: &rcmgr.ResourceLimits{
// Allow unlimited outbound streams
StreamsOutbound: rcmgr.Unlimited,
},
Peer: map[peer.ID]rcmgr.ResourceLimits{
noisyNeighbor: {
// No inbound connections from this peer
ConnsInbound: rcmgr.BlockAllLimit,
// But let me open connections to them
Conns: rcmgr.DefaultLimit,
ConnsOutbound: rcmgr.DefaultLimit,
// No inbound streams from this peer
StreamsInbound: rcmgr.BlockAllLimit,
// And let me open unlimited (by me) outbound streams (the peer may have their own limits on me)
StreamsOutbound: rcmgr.Unlimited,
},
},
}
jsonBytes, _ := json.Marshal(&cfg)
// string(jsonBytes)
// {
// "System": {
// "StreamsOutbound": "unlimited"
// },
// "Peer": {
// "QmVvtzcZgCkMnSFf2dnrBPXrWuNFWNM9J3MpZQCvWPuVZf": {
// "StreamsInbound": "blockAll",
// "StreamsOutbound": "unlimited",
// "ConnsInbound": "blockAll"
// }
// }
// }
```
This will omit defaults from the JSON output. It will also serialize the
blockAll, and unlimited values explicitly.
The `Memory` field is serialized as a string to workaround the JSON limitation
of 32 bit integers (`Memory` is an int64).
## Basic Resources
### Memory
Perhaps the most fundamental resource is memory, and in particular
buffers used for network operations. The system must provide an
interface for components to reserve memory that accounts for buffers
(and possibly other live objects), which is scoped within the component.
Before a new buffer is allocated, the component should try a memory
reservation, which can fail if the resource limit is exceeded. It is
then up to the component to react to the error condition, depending on
the situation. For example, a muxer failing to grow a buffer in
response to a window change should simply retain the old buffer and
operate at perhaps degraded performance.
### File Descriptors
File descriptors are an important resource that uses memory (and
computational time) at the system level. They are also a scarce
resource, as typically (unless the user explicitly intervenes) they
are constrained by the system. Exhaustion of file descriptors may
render the application incapable of operating (e.g., because it is
unable to open a file). This is important for libp2p because most
operating systems represent sockets as file descriptors.
### Connections
Connections are a higher-level concept endemic to libp2p; in order to
communicate with another peer, a connection must first be
established. Connections are an important resource in libp2p, as they
consume memory, goroutines, and possibly file descriptors.
We distinguish between inbound and outbound connections, as the former
are initiated by remote peers and consume resources in response to
network events and thus need to be tightly controlled in order to
protect the application from overload or attack. Outbound
connections are typically initiated by the application's volition and
don't need to be controlled as tightly. However, outbound connections
still consume resources and may be initiated in response to network
events because of (potentially faulty) application logic, so they
still need to be constrained.
### Streams
Streams are the fundamental object of interaction in libp2p; all
protocol interactions happen through a stream that goes over some
connection. Streams are a fundamental resource in libp2p, as they
consume memory and goroutines at all levels of the stack.
Streams always belong to a peer, specify a protocol and they may
belong to some service in the system. Hence, this suggests that apart
from global limits, we can constrain stream usage at finer
granularity, at the protocol and service level.
Once again, we disinguish between inbound and outbound streams.
Inbound streams are initiated by remote peers and consume resources in
response to network events; controlling inbound stream usage is again
paramount for protecting the system from overload or attack.
Outbound streams are normally initiated by the application or some
service in the system in order to effect some protocol
interaction. However, they can also be initiated in response to
network events because of application or service logic, so we still
need to constrain them.
## Resource Scopes
The Resource Manager is based on the concept of resource
scopes. Resource Scopes account for resource usage that is temporally
delimited for the span of the scope. Resource Scopes conceptually
form a DAG, providing us with a mechanism to enforce multiresolution
resource accounting. Downstream resource usage is aggregated at scopes
higher up the graph.
The following diagram depicts the canonical scope graph:
```
System
+------------> Transient.............+................+
| . .
+------------> Service------------- . ----------+ .
| . | .
+-------------> Protocol----------- . ----------+ .
| . | .
+-------------->* Peer \/ | .
+------------> Connection | .
| \/ \/
+---------------------------> Stream
```
### The System Scope
The system scope is the top level scope that accounts for global
resource usage at all levels of the system. This scope nests and
constrains all other scopes and institutes global hard limits.
### The Transient Scope
The transient scope accounts for resources that are in the process of
full establishment. For instance, a new connection prior to the
handshake does not belong to any peer, but it still needs to be
constrained as this opens an avenue for attacks in transient resource
usage. Similarly, a stream that has not negotiated a protocol yet is
constrained by the transient scope.
The transient scope effectively represents a DMZ (DeMilitarized Zone),
where resource usage can be accounted for connections and streams that
are not fully established.
### The Allowlist System Scope
Same as the normal system scope above, but is used if the normal system scope is
already at its limits and the resource is from an allowlisted peer. See
[Allowlisting multiaddrs to mitigate eclipse
attacks](#allowlisting-multiaddrs-to-mitigate-eclipse-attacks) see for more
information.
### The Allowlist Transient Scope
Same as the normal transient scope above, but is used if the normal transient
scope is already at its limits and the resource is from an allowlisted peer. See
[Allowlisting multiaddrs to mitigate eclipse
attacks](#allowlisting-multiaddrs-to-mitigate-eclipse-attacks) see for more
information.
### Service Scopes
The system is typically organized across services, which may be
ambient and provide basic functionality to the system (e.g. identify,
autonat, relay, etc). Alternatively, services may be explicitly
instantiated by the application, and provide core components of its
functionality (e.g. pubsub, the DHT, etc).
Services are logical groupings of streams that implement protocol flow
and may additionally consume resources such as memory. Services
typically have at least one stream handler, so they are subject to
inbound stream creation and resource usage in response to network
events. As such, the system explicitly models them allowing for
isolated resource usage that can be tuned by the user.
### Protocol Scopes
Protocol Scopes account for resources at the protocol level. They are
an intermediate resource scope which can constrain streams which may
not have a service associated or for resource control within a
service. It also provides an opportunity for system operators to
explicitly restrict specific protocols.
For instance, a service that is not aware of the resource manager and
has not been ported to mark its streams, may still gain limits
transparently without any programmer intervention. Furthermore, the
protocol scope can constrain resource usage for services that
implement multiple protocols for the sake of backwards
compatibility. A tighter limit in some older protocol can protect the
application from resource consumption caused by legacy clients or
potential attacks.
For a concrete example, consider pubsub with the gossipsub router: the
service also understands the floodsub protocol for backwards
compatibility and support for unsophisticated clients that are lagging
in the implementation effort. By specifying a lower limit for the
floodsub protocol, we can can constrain the service level for legacy
clients using an inefficient protocol.
### Peer Scopes
The peer scope accounts for resource usage by an individual peer. This
constrains connections and streams and limits the blast radius of
resource consumption by a single remote peer.
This ensures that no single peer can use more resources than allowed
by the peer limits. Every peer has a default limit, but the programmer
may raise (or lower) limits for specific peers.
### Connection Scopes
The connection scope is delimited to the duration of a connection and
constrains resource usage by a single connection. The scope is a leaf
in the DAG, with a span that begins when a connection is established
and ends when the connection is closed. Its resources are aggregated
to the resource usage of a peer.
### Stream Scopes
The stream scope is delimited to the duration of a stream, and
constrains resource usage by a single stream. This scope is also a
leaf in the DAG, with span that begins when a stream is created and
ends when the stream is closed. Its resources are aggregated to the
resource usage of a peer, and constrained by a service and protocol
scope.
### User Transaction Scopes
User transaction scopes can be created as a child of any extant
resource scope, and provide the programmer with a delimited scope for
easy resource accounting. Transactions may form a tree that is rooted
to some canonical scope in the scope DAG.
For instance, a programmer may create a transaction scope within a
service that accounts for some control flow delimited resource
usage. Similarly, a programmer may create a transaction scope for some
interaction within a stream, e.g. a Request/Response interaction that
uses a buffer.
## Limits
Each resource scope has an associated limit object, which designates
limits for all [basic resources](#basic-resources). The limit is checked every time some
resource is reserved and provides the system with an opportunity to
constrain resource usage.
There are separate limits for each class of scope, allowing for
multiresolution and aggregate resource accounting. As such, we have
limits for the system and transient scopes, default and specific
limits for services, protocols, and peers, and limits for connections
and streams.
### Scaling Limits
When building software that is supposed to run on many different kind of machines,
with various memory and CPU configurations, it is desirable to have limits that
scale with the size of the machine.
This is done using the `ScalingLimitConfig`. For every scope, this configuration
struct defines the absolutely bare minimum limits, and an (optional) increase of
these limits, which will be applied on nodes that have sufficient memory.
A `ScalingLimitConfig` can be converted into a `ConcreteLimitConfig` (which can then be
used to initialize a fixed limiter with `NewFixedLimiter`) by calling the `Scale` method.
The `Scale` method takes two parameters: the amount of memory and the number of file
descriptors that an application is willing to dedicate to libp2p.
These amounts will differ between use cases. A blockchain node running on a dedicated
server might have a lot of memory, and dedicate 1/4 of that memory to libp2p. On the
other end of the spectrum, a desktop companion application running as a background
task on a consumer laptop will probably dedicate significantly less than 1/4 of its system
memory to libp2p.
For convenience, the `ScalingLimitConfig` also provides an `AutoScale` method,
which determines the amount of memory and file descriptors available on the
system, and dedicates up to 1/8 of the memory and 1/2 of the file descriptors to
libp2p.
For example, one might set:
```go
var scalingLimits = ScalingLimitConfig{
SystemBaseLimit: BaseLimit{
ConnsInbound: 64,
ConnsOutbound: 128,
Conns: 128,
StreamsInbound: 512,
StreamsOutbound: 1024,
Streams: 1024,
Memory: 128 << 20,
FD: 256,
},
SystemLimitIncrease: BaseLimitIncrease{
ConnsInbound: 32,
ConnsOutbound: 64,
Conns: 64,
StreamsInbound: 256,
StreamsOutbound: 512,
Streams: 512,
Memory: 256 << 20,
FDFraction: 1,
},
}
```
The base limit (`SystemBaseLimit`) here is the minimum configuration that any
node will have, no matter how little memory it possesses. For every GB of memory
passed into the `Scale` method, an increase of (`SystemLimitIncrease`) is added.
For Example, calling `Scale` with 4 GB of memory will result in a limit of 384 for
`Conns` (128 + 4*64).
The `FDFraction` defines how many of the file descriptors are allocated to this
scope. In the example above, when called with a file descriptor value of 1000,
this would result in a limit of 1000 (1000 * 1) file descriptors for the system
scope. See `TestReadmeExample` in `limit_test.go`.
Note that we only showed the configuration for the system scope here, equivalent
configuration options apply to all other scopes as well.
### Default limits
By default the resource manager ships with some reasonable scaling limits and
makes a reasonable guess at how much system memory you want to dedicate to the
go-libp2p process. For the default definitions see [`DefaultLimits` and
`ScalingLimitConfig.AutoScale()`](./limit_defaults.go).
### Tweaking Defaults
If the defaults seem mostly okay, but you want to adjust one facet you can
simply copy the default struct object and update the field you want to change. You can
apply changes to a `BaseLimit`, `BaseLimitIncrease`, and `ConcreteLimitConfig` with
`.Apply`.
Example
```
// An example on how to tweak the default limits
tweakedDefaults := DefaultLimits
tweakedDefaults.ProtocolBaseLimit.Streams = 1024
tweakedDefaults.ProtocolBaseLimit.StreamsInbound = 512
tweakedDefaults.ProtocolBaseLimit.StreamsOutbound = 512
```
### How to tune your limits
Once you've set your limits and monitoring (see [Monitoring](#monitoring) below)
you can now tune your limits better. The `rcmgr_blocked_resources` metric will
tell you what was blocked and for what scope. If you see a steady stream of
these blocked requests it means your resource limits are too low for your usage.
If you see a rare sudden spike, this is okay and it means the resource manager
protected you from some anomaly.
### How to disable limits
Sometimes disabling all limits is useful when you want to see how much
resources you use during normal operation. You can then use this information to
define your initial limits. Disable the limits by using `InfiniteLimits`.
### Debug "resource limit exceeded" errors
These errors occur whenever a limit is hit. For example, you'll get this error if
you are at your limit for the number of streams you can have, and you try to
open one more.
Example Log:
```
2022-08-12T15:49:35.459-0700 DEBUG rcmgr go-libp2p-resource-manager@v0.5.3/scope.go:541 blocked connection from constraining edge {"scope": "conn-19667", "edge": "system", "direction": "Inbound", "usefd": false, "current": 100, "attempted": 1, "limit": 100, "stat": {"NumStreamsInbound":28,"NumStreamsOutbound":66,"NumConnsInbound":37,"NumConnsOutbound":63,"NumFD":33,"Memory":8687616}, "error": "system: cannot reserve connection: resource limit exceeded"}
```
The log line above is an example log line that gets emitted if you enable debug
logging in the resource manager. You can do this by setting the environment
variable `GOLOG_LOG_LEVEL="rcmgr=debug"`. By default only the error is
returned to the caller, and nothing is logged by the resource manager itself.
The log line message (and returned error) will tell you which resource limit was
hit (connection in the log above) and what blocked it (in this case it was the
system scope that blocked it). The log will also include some more information
about the current usage of the resources. In the example log above, there is a
limit of 100 connections, and you can see that we have 37 inbound connections
and 63 outbound connections. We've reached the limit and the resource manager
will block any further connections.
The next step in debugging is seeing if this is a recurring problem or just a
transient error. If it's a transient error it's okay to ignore it since the
resource manager was doing its job in keeping resource usage under the limit. If
it's recurring then you should understand what's causing you to hit these limits
and either refactor your application or raise the limits.
To check if it's a recurring problem you can count the number of times you've
seen the `"resource limit exceeded"` error over time. You can also check the
`rcmgr_blocked_resources` metric to see how many times the resource manager has
blocked a resource over time.
![Example graph of blocked resources over time](https://bafkreibul6qipnax5s42abv3jc6bolhd7pju3zbl4rcvdaklmk52f6cznu.ipfs.w3s.link/)
If the resource is blocked by a protocol-level scope, take a look at the various
resource usages in the metrics. For example, if you run into a new stream being blocked,
you can check the
`rcmgr_streams` metric and the "Streams by protocol" graph in the Grafana
dashboard (assuming you've set that up or something similar  see
[Monitoring](#monitoring)) to understand the usage pattern of that specific
protocol. This can help answer questions such as: "Am I constantly around my
limit?", "Does it make sense to raise my limit?", "Are there any patterns around
hitting this limit?", and "should I refactor my protocol implementation?"
## Monitoring
Once you have limits set, you'll want to monitor to see if you're running into
your limits often. This could be a sign that you need to raise your limits
(your process is more intensive than you originally thought) or that you need
to fix something in your application (surely you don't need over 1000 streams?).
There are Prometheus metrics that can be hooked up to the resource manager. See
`obs/stats_test.go` for an example on how to enable this, and `DefaultViews` in
`stats.go` for recommended views. These metrics can be hooked up to Prometheus
or any other platform that can scrape a prometheus endpoint.
There is also an included Grafana dashboard to help kickstart your
observability into the resource manager. Find more information about it at
[here](./../../../dashboards/resource-manager/README.md).
## Allowlisting multiaddrs to mitigate eclipse attacks
If you have a set of trusted peers and IP addresses, you can use the resource
manager's [Allowlist](./docs/allowlist.md) to protect yourself from eclipse
attacks. The set of peers in the allowlist will have their own limits in case
the normal limits are reached. This means you will always be able to connect to
these trusted peers even if you've already reached your system limits.
Look at `WithAllowlistedMultiaddrs` and its example in the GoDoc to learn more.
## ConnManager vs Resource Manager
go-libp2p already includes a [connection
manager](https://pkg.go.dev/github.com/libp2p/go-libp2p/core/connmgr#ConnManager),
so what's the difference between the `ConnManager` and the `ResourceManager`?
ConnManager:
1. Configured with a low and high watermark number of connections.
2. Attempts to maintain the number of connections between the low and high
markers.
3. Connections can be given metadata and weight (e.g. a hole punched
connection is more valuable than a connection to a publicly addressable
endpoint since it took more effort to make the hole punched connection).
4. The ConnManager will trim connections once the high watermark is reached. and
trim down to the low watermark.
5. Won't block adding another connection above the high watermark, but will
trigger the trim mentioned above.
6. Can trim and prioritize connections with custom logic.
7. No concept of scopes (like the resource manager).
Resource Manager:
1. Configured with limits on the number of outgoing and incoming connections at
different [resource scopes](#resource-scopes).
2. Will block adding any more connections if any of the scope-specific limits would be exceeded.
The natural question when comparing these two managers is "how do the watermarks
and limits interact with each other?". The short answer is that they don't know
about each other. This can lead to some surprising subtleties, such as the
trimming never happening because the resource manager's limit is lower than the
high watermark. This is confusing, and we'd like to fix it. The issue is
captured in [go-libp2p#1640](https://github.com/libp2p/go-libp2p/issues/1640).
When configuring the resource manager and connection manager, you should set the
limits in the resource manager as your hard limits that you would never want to
go over, and set the low/high watermarks as the range at which your application
works best.
## Examples
Here we consider some concrete examples that can ellucidate the abstract
design as described so far.
### Stream Lifetime
Let's consider a stream and the limits that apply to it.
When the stream scope is first opened, it is created by calling
`ResourceManager.OpenStream`.
Initially the stream is constrained by:
- the system scope, where global hard limits apply.
- the transient scope, where unnegotiated streams live.
- the peer scope, where the limits for the peer at the other end of the stream
apply.
Once the protocol has been negotiated, the protocol is set by calling
`StreamManagementScope.SetProtocol`. The constraint from the
transient scope is removed and the stream is now constrained by the
protocol instead.
More specifically, the following constraints apply:
- the system scope, where global hard limits apply.
- the peer scope, where the limits for the peer at the other end of the stream
apply.
- the protocol scope, where the limits of the specific protocol used apply.
The existence of the protocol limit allows us to implicitly constrain
streams for services that have not been ported to the resource manager
yet. Once the programmer attaches a stream to a service by calling
`StreamScope.SetService`, the stream resources are aggregated and constrained
by the service scope in addition to its protocol scope.
More specifically the following constraints apply:
- the system scope, where global hard limits apply.
- the peer scope, where the limits for the peer at the other end of the stream
apply.
- the service scope, where the limits of the specific service owning the stream apply.
- the protcol scope, where the limits of the specific protocol for the stream apply.
The resource transfer that happens in the `SetProtocol` and `SetService`
gives the opportunity to the resource manager to gate the streams. If
the transfer results in exceeding the scope limits, then a error
indicating "resource limit exceeded" is returned. The wrapped error
includes the name of the scope rejecting the resource acquisition to
aid understanding of applicable limits. Note that the (wrapped) error
implements `net.Error` and is marked as temporary, so that the
programmer can handle by backoff retry.
## Implementation Notes
- The package only exports a constructor for the resource manager and
basic types for defining limits. Internals are not exposed.
- Internally, there is a resources object that is embedded in every scope and
implements resource accounting.
- There is a single implementation of a generic resource scope, that
provides all necessary interface methods.
- There are concrete types for all canonical scopes, embedding a
pointer to a generic resource scope.
- Peer and Protocol scopes, which may be created in response to
network events, are periodically garbage collected.
## Design Considerations
- The Resource Manager must account for basic resource usage at all
levels of the stack, from the internals to application components
that use the network facilities of libp2p.
- Basic resources include memory, streams, connections, and file
descriptors. These account for both space and time used by
the stack, as each resource has a direct effect on the system
availability and performance.
- The design must support seamless integration for user applications,
which should reap the benefits of resource management without any
changes. That is, existing applications should be oblivious of the
resource manager and transparently obtain limits which protects it
from resource exhaustion and OOM conditions.
- At the same time, the design must support opt-in resource usage
accounting for applications that want to explicitly utilize the
facilities of the system to inform about and constrain their own
resource usage.
- The design must allow the user to set their own limits, which can be
static (fixed) or dynamic.

View File

@@ -0,0 +1,216 @@
package rcmgr
import (
"bytes"
"errors"
"fmt"
"net"
"sync"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
type Allowlist struct {
mu sync.RWMutex
// a simple structure of lists of networks. There is probably a faster way
// to check if an IP address is in this network than iterating over this
// list, but this is good enough for small numbers of networks (<1_000).
// Analyze the benchmark before trying to optimize this.
// Any peer with these IPs are allowed
allowedNetworks []*net.IPNet
// Only the specified peers can use these IPs
allowedPeerByNetwork map[peer.ID][]*net.IPNet
}
// WithAllowlistedMultiaddrs sets the multiaddrs to be in the allowlist
func WithAllowlistedMultiaddrs(mas []multiaddr.Multiaddr) Option {
return func(rm *resourceManager) error {
for _, ma := range mas {
err := rm.allowlist.Add(ma)
if err != nil {
return err
}
}
return nil
}
}
func newAllowlist() Allowlist {
return Allowlist{
allowedPeerByNetwork: make(map[peer.ID][]*net.IPNet),
}
}
func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, peer.ID, error) {
var ipString string
var mask string
var allowedPeerStr string
var allowedPeer peer.ID
var isIPV4 bool
multiaddr.ForEach(ma, func(c multiaddr.Component) bool {
if c.Protocol().Code == multiaddr.P_IP4 || c.Protocol().Code == multiaddr.P_IP6 {
isIPV4 = c.Protocol().Code == multiaddr.P_IP4
ipString = c.Value()
}
if c.Protocol().Code == multiaddr.P_IPCIDR {
mask = c.Value()
}
if c.Protocol().Code == multiaddr.P_P2P {
allowedPeerStr = c.Value()
}
return ipString == "" || mask == "" || allowedPeerStr == ""
})
if ipString == "" {
return nil, allowedPeer, errors.New("missing ip address")
}
if allowedPeerStr != "" {
var err error
allowedPeer, err = peer.Decode(allowedPeerStr)
if err != nil {
return nil, allowedPeer, fmt.Errorf("failed to decode allowed peer: %w", err)
}
}
if mask == "" {
ip := net.ParseIP(ipString)
if ip == nil {
return nil, allowedPeer, errors.New("invalid ip address")
}
var mask net.IPMask
if isIPV4 {
mask = net.CIDRMask(32, 32)
} else {
mask = net.CIDRMask(128, 128)
}
net := &net.IPNet{IP: ip, Mask: mask}
return net, allowedPeer, nil
}
_, ipnet, err := net.ParseCIDR(ipString + "/" + mask)
return ipnet, allowedPeer, err
}
// Add takes a multiaddr and adds it to the allowlist. The multiaddr should be
// an ip address of the peer with or without a `/p2p` protocol.
// e.g. /ip4/1.2.3.4/p2p/QmFoo, /ip4/1.2.3.4, and /ip4/1.2.3.0/ipcidr/24 are valid.
// /p2p/QmFoo is not valid.
func (al *Allowlist) Add(ma multiaddr.Multiaddr) error {
ipnet, allowedPeer, err := toIPNet(ma)
if err != nil {
return err
}
al.mu.Lock()
defer al.mu.Unlock()
if allowedPeer != peer.ID("") {
// We have a peerID constraint
if al.allowedPeerByNetwork == nil {
al.allowedPeerByNetwork = make(map[peer.ID][]*net.IPNet)
}
al.allowedPeerByNetwork[allowedPeer] = append(al.allowedPeerByNetwork[allowedPeer], ipnet)
} else {
al.allowedNetworks = append(al.allowedNetworks, ipnet)
}
return nil
}
func (al *Allowlist) Remove(ma multiaddr.Multiaddr) error {
ipnet, allowedPeer, err := toIPNet(ma)
if err != nil {
return err
}
al.mu.Lock()
defer al.mu.Unlock()
ipNetList := al.allowedNetworks
if allowedPeer != "" {
// We have a peerID constraint
ipNetList = al.allowedPeerByNetwork[allowedPeer]
}
if ipNetList == nil {
return nil
}
i := len(ipNetList)
for i > 0 {
i--
if ipNetList[i].IP.Equal(ipnet.IP) && bytes.Equal(ipNetList[i].Mask, ipnet.Mask) {
// swap remove
ipNetList[i] = ipNetList[len(ipNetList)-1]
ipNetList = ipNetList[:len(ipNetList)-1]
// We only remove one thing
break
}
}
if allowedPeer != "" {
al.allowedPeerByNetwork[allowedPeer] = ipNetList
} else {
al.allowedNetworks = ipNetList
}
return nil
}
func (al *Allowlist) Allowed(ma multiaddr.Multiaddr) bool {
ip, err := manet.ToIP(ma)
if err != nil {
return false
}
al.mu.RLock()
defer al.mu.RUnlock()
for _, network := range al.allowedNetworks {
if network.Contains(ip) {
return true
}
}
for _, allowedNetworks := range al.allowedPeerByNetwork {
for _, network := range allowedNetworks {
if network.Contains(ip) {
return true
}
}
}
return false
}
func (al *Allowlist) AllowedPeerAndMultiaddr(peerID peer.ID, ma multiaddr.Multiaddr) bool {
ip, err := manet.ToIP(ma)
if err != nil {
return false
}
al.mu.RLock()
defer al.mu.RUnlock()
for _, network := range al.allowedNetworks {
if network.Contains(ip) {
// We found a match that isn't constrained by a peerID
return true
}
}
if expectedNetworks, ok := al.allowedPeerByNetwork[peerID]; ok {
for _, expectedNetwork := range expectedNetworks {
if expectedNetwork.Contains(ip) {
return true
}
}
}
return false
}

View File

@@ -0,0 +1,81 @@
package rcmgr
import (
"errors"
"github.com/libp2p/go-libp2p/core/network"
)
type ErrStreamOrConnLimitExceeded struct {
current, attempted, limit int
err error
}
func (e *ErrStreamOrConnLimitExceeded) Error() string { return e.err.Error() }
func (e *ErrStreamOrConnLimitExceeded) Unwrap() error { return e.err }
// edge may be "" if this is not an edge error
func logValuesStreamLimit(scope, edge string, dir network.Direction, stat network.ScopeStat, err error) []interface{} {
logValues := make([]interface{}, 0, 2*8)
logValues = append(logValues, "scope", scope)
if edge != "" {
logValues = append(logValues, "edge", edge)
}
logValues = append(logValues, "direction", dir)
var e *ErrStreamOrConnLimitExceeded
if errors.As(err, &e) {
logValues = append(logValues,
"current", e.current,
"attempted", e.attempted,
"limit", e.limit,
)
}
return append(logValues, "stat", stat, "error", err)
}
// edge may be "" if this is not an edge error
func logValuesConnLimit(scope, edge string, dir network.Direction, usefd bool, stat network.ScopeStat, err error) []interface{} {
logValues := make([]interface{}, 0, 2*9)
logValues = append(logValues, "scope", scope)
if edge != "" {
logValues = append(logValues, "edge", edge)
}
logValues = append(logValues, "direction", dir, "usefd", usefd)
var e *ErrStreamOrConnLimitExceeded
if errors.As(err, &e) {
logValues = append(logValues,
"current", e.current,
"attempted", e.attempted,
"limit", e.limit,
)
}
return append(logValues, "stat", stat, "error", err)
}
type ErrMemoryLimitExceeded struct {
current, attempted, limit int64
priority uint8
err error
}
func (e *ErrMemoryLimitExceeded) Error() string { return e.err.Error() }
func (e *ErrMemoryLimitExceeded) Unwrap() error { return e.err }
// edge may be "" if this is not an edge error
func logValuesMemoryLimit(scope, edge string, stat network.ScopeStat, err error) []interface{} {
logValues := make([]interface{}, 0, 2*8)
logValues = append(logValues, "scope", scope)
if edge != "" {
logValues = append(logValues, "edge", edge)
}
var e *ErrMemoryLimitExceeded
if errors.As(err, &e) {
logValues = append(logValues,
"current", e.current,
"attempted", e.attempted,
"priority", e.priority,
"limit", e.limit,
)
}
return append(logValues, "stat", stat, "error", err)
}

View File

@@ -0,0 +1,147 @@
package rcmgr
import (
"bytes"
"sort"
"strings"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
)
// ResourceScopeLimiter is a trait interface that allows you to access scope limits.
type ResourceScopeLimiter interface {
Limit() Limit
SetLimit(Limit)
}
var _ ResourceScopeLimiter = (*resourceScope)(nil)
// ResourceManagerStat is a trait that allows you to access resource manager state.
type ResourceManagerState interface {
ListServices() []string
ListProtocols() []protocol.ID
ListPeers() []peer.ID
Stat() ResourceManagerStat
}
type ResourceManagerStat struct {
System network.ScopeStat
Transient network.ScopeStat
Services map[string]network.ScopeStat
Protocols map[protocol.ID]network.ScopeStat
Peers map[peer.ID]network.ScopeStat
}
var _ ResourceManagerState = (*resourceManager)(nil)
func (s *resourceScope) Limit() Limit {
s.Lock()
defer s.Unlock()
return s.rc.limit
}
func (s *resourceScope) SetLimit(limit Limit) {
s.Lock()
defer s.Unlock()
s.rc.limit = limit
}
func (s *protocolScope) SetLimit(limit Limit) {
s.rcmgr.setStickyProtocol(s.proto)
s.resourceScope.SetLimit(limit)
}
func (s *peerScope) SetLimit(limit Limit) {
s.rcmgr.setStickyPeer(s.peer)
s.resourceScope.SetLimit(limit)
}
func (r *resourceManager) ListServices() []string {
r.mx.Lock()
defer r.mx.Unlock()
result := make([]string, 0, len(r.svc))
for svc := range r.svc {
result = append(result, svc)
}
sort.Slice(result, func(i, j int) bool {
return strings.Compare(result[i], result[j]) < 0
})
return result
}
func (r *resourceManager) ListProtocols() []protocol.ID {
r.mx.Lock()
defer r.mx.Unlock()
result := make([]protocol.ID, 0, len(r.proto))
for p := range r.proto {
result = append(result, p)
}
sort.Slice(result, func(i, j int) bool {
return result[i] < result[j]
})
return result
}
func (r *resourceManager) ListPeers() []peer.ID {
r.mx.Lock()
defer r.mx.Unlock()
result := make([]peer.ID, 0, len(r.peer))
for p := range r.peer {
result = append(result, p)
}
sort.Slice(result, func(i, j int) bool {
return bytes.Compare([]byte(result[i]), []byte(result[j])) < 0
})
return result
}
func (r *resourceManager) Stat() (result ResourceManagerStat) {
r.mx.Lock()
svcs := make([]*serviceScope, 0, len(r.svc))
for _, svc := range r.svc {
svcs = append(svcs, svc)
}
protos := make([]*protocolScope, 0, len(r.proto))
for _, proto := range r.proto {
protos = append(protos, proto)
}
peers := make([]*peerScope, 0, len(r.peer))
for _, peer := range r.peer {
peers = append(peers, peer)
}
r.mx.Unlock()
// Note: there is no global lock, so the system is updating while we are dumping its state...
// as such stats might not exactly add up to the system level; we take the system stat
// last nonetheless so that this is the most up-to-date snapshot
result.Peers = make(map[peer.ID]network.ScopeStat, len(peers))
for _, peer := range peers {
result.Peers[peer.peer] = peer.Stat()
}
result.Protocols = make(map[protocol.ID]network.ScopeStat, len(protos))
for _, proto := range protos {
result.Protocols[proto.proto] = proto.Stat()
}
result.Services = make(map[string]network.ScopeStat, len(svcs))
for _, svc := range svcs {
result.Services[svc.service] = svc.Stat()
}
result.Transient = r.transient.Stat()
result.System = r.system.Stat()
return result
}

View File

@@ -0,0 +1,297 @@
/*
Package rcmgr is the resource manager for go-libp2p. This allows you to track
resources being used throughout your go-libp2p process. As well as making sure
that the process doesn't use more resources than what you define as your
limits. The resource manager only knows about things it is told about, so it's
the responsibility of the user of this library (either go-libp2p or a go-libp2p
user) to make sure they check with the resource manager before actually
allocating the resource.
*/
package rcmgr
import (
"encoding/json"
"io"
"math"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
)
// Limit is an object that specifies basic resource limits.
type Limit interface {
// GetMemoryLimit returns the (current) memory limit.
GetMemoryLimit() int64
// GetStreamLimit returns the stream limit, for inbound or outbound streams.
GetStreamLimit(network.Direction) int
// GetStreamTotalLimit returns the total stream limit
GetStreamTotalLimit() int
// GetConnLimit returns the connection limit, for inbound or outbound connections.
GetConnLimit(network.Direction) int
// GetConnTotalLimit returns the total connection limit
GetConnTotalLimit() int
// GetFDLimit returns the file descriptor limit.
GetFDLimit() int
}
// Limiter is the interface for providing limits to the resource manager.
type Limiter interface {
GetSystemLimits() Limit
GetTransientLimits() Limit
GetAllowlistedSystemLimits() Limit
GetAllowlistedTransientLimits() Limit
GetServiceLimits(svc string) Limit
GetServicePeerLimits(svc string) Limit
GetProtocolLimits(proto protocol.ID) Limit
GetProtocolPeerLimits(proto protocol.ID) Limit
GetPeerLimits(p peer.ID) Limit
GetStreamLimits(p peer.ID) Limit
GetConnLimits() Limit
}
// NewDefaultLimiterFromJSON creates a new limiter by parsing a json configuration,
// using the default limits for fallback.
func NewDefaultLimiterFromJSON(in io.Reader) (Limiter, error) {
return NewLimiterFromJSON(in, DefaultLimits.AutoScale())
}
// NewLimiterFromJSON creates a new limiter by parsing a json configuration.
func NewLimiterFromJSON(in io.Reader, defaults ConcreteLimitConfig) (Limiter, error) {
cfg, err := readLimiterConfigFromJSON(in, defaults)
if err != nil {
return nil, err
}
return &fixedLimiter{cfg}, nil
}
func readLimiterConfigFromJSON(in io.Reader, defaults ConcreteLimitConfig) (ConcreteLimitConfig, error) {
var cfg PartialLimitConfig
if err := json.NewDecoder(in).Decode(&cfg); err != nil {
return ConcreteLimitConfig{}, err
}
return cfg.Build(defaults), nil
}
// fixedLimiter is a limiter with fixed limits.
type fixedLimiter struct {
ConcreteLimitConfig
}
var _ Limiter = (*fixedLimiter)(nil)
func NewFixedLimiter(conf ConcreteLimitConfig) Limiter {
log.Debugw("initializing new limiter with config", "limits", conf)
return &fixedLimiter{conf}
}
// BaseLimit is a mixin type for basic resource limits.
type BaseLimit struct {
Streams int `json:",omitempty"`
StreamsInbound int `json:",omitempty"`
StreamsOutbound int `json:",omitempty"`
Conns int `json:",omitempty"`
ConnsInbound int `json:",omitempty"`
ConnsOutbound int `json:",omitempty"`
FD int `json:",omitempty"`
Memory int64 `json:",omitempty"`
}
func valueOrBlockAll(n int) LimitVal {
if n == 0 {
return BlockAllLimit
} else if n == math.MaxInt {
return Unlimited
}
return LimitVal(n)
}
func valueOrBlockAll64(n int64) LimitVal64 {
if n == 0 {
return BlockAllLimit64
} else if n == math.MaxInt {
return Unlimited64
}
return LimitVal64(n)
}
// ToResourceLimits converts the BaseLimit to a ResourceLimits
func (l BaseLimit) ToResourceLimits() ResourceLimits {
return ResourceLimits{
Streams: valueOrBlockAll(l.Streams),
StreamsInbound: valueOrBlockAll(l.StreamsInbound),
StreamsOutbound: valueOrBlockAll(l.StreamsOutbound),
Conns: valueOrBlockAll(l.Conns),
ConnsInbound: valueOrBlockAll(l.ConnsInbound),
ConnsOutbound: valueOrBlockAll(l.ConnsOutbound),
FD: valueOrBlockAll(l.FD),
Memory: valueOrBlockAll64(l.Memory),
}
}
// Apply overwrites all zero-valued limits with the values of l2
// Must not use a pointer receiver.
func (l *BaseLimit) Apply(l2 BaseLimit) {
if l.Streams == 0 {
l.Streams = l2.Streams
}
if l.StreamsInbound == 0 {
l.StreamsInbound = l2.StreamsInbound
}
if l.StreamsOutbound == 0 {
l.StreamsOutbound = l2.StreamsOutbound
}
if l.Conns == 0 {
l.Conns = l2.Conns
}
if l.ConnsInbound == 0 {
l.ConnsInbound = l2.ConnsInbound
}
if l.ConnsOutbound == 0 {
l.ConnsOutbound = l2.ConnsOutbound
}
if l.Memory == 0 {
l.Memory = l2.Memory
}
if l.FD == 0 {
l.FD = l2.FD
}
}
// BaseLimitIncrease is the increase per GiB of allowed memory.
type BaseLimitIncrease struct {
Streams int `json:",omitempty"`
StreamsInbound int `json:",omitempty"`
StreamsOutbound int `json:",omitempty"`
Conns int `json:",omitempty"`
ConnsInbound int `json:",omitempty"`
ConnsOutbound int `json:",omitempty"`
// Memory is in bytes. Values over 1>>30 (1GiB) don't make sense.
Memory int64 `json:",omitempty"`
// FDFraction is expected to be >= 0 and <= 1.
FDFraction float64 `json:",omitempty"`
}
// Apply overwrites all zero-valued limits with the values of l2
// Must not use a pointer receiver.
func (l *BaseLimitIncrease) Apply(l2 BaseLimitIncrease) {
if l.Streams == 0 {
l.Streams = l2.Streams
}
if l.StreamsInbound == 0 {
l.StreamsInbound = l2.StreamsInbound
}
if l.StreamsOutbound == 0 {
l.StreamsOutbound = l2.StreamsOutbound
}
if l.Conns == 0 {
l.Conns = l2.Conns
}
if l.ConnsInbound == 0 {
l.ConnsInbound = l2.ConnsInbound
}
if l.ConnsOutbound == 0 {
l.ConnsOutbound = l2.ConnsOutbound
}
if l.Memory == 0 {
l.Memory = l2.Memory
}
if l.FDFraction == 0 {
l.FDFraction = l2.FDFraction
}
}
func (l BaseLimit) GetStreamLimit(dir network.Direction) int {
if dir == network.DirInbound {
return l.StreamsInbound
} else {
return l.StreamsOutbound
}
}
func (l BaseLimit) GetStreamTotalLimit() int {
return l.Streams
}
func (l BaseLimit) GetConnLimit(dir network.Direction) int {
if dir == network.DirInbound {
return l.ConnsInbound
} else {
return l.ConnsOutbound
}
}
func (l BaseLimit) GetConnTotalLimit() int {
return l.Conns
}
func (l BaseLimit) GetFDLimit() int {
return l.FD
}
func (l BaseLimit) GetMemoryLimit() int64 {
return l.Memory
}
func (l *fixedLimiter) GetSystemLimits() Limit {
return &l.system
}
func (l *fixedLimiter) GetTransientLimits() Limit {
return &l.transient
}
func (l *fixedLimiter) GetAllowlistedSystemLimits() Limit {
return &l.allowlistedSystem
}
func (l *fixedLimiter) GetAllowlistedTransientLimits() Limit {
return &l.allowlistedTransient
}
func (l *fixedLimiter) GetServiceLimits(svc string) Limit {
sl, ok := l.service[svc]
if !ok {
return &l.serviceDefault
}
return &sl
}
func (l *fixedLimiter) GetServicePeerLimits(svc string) Limit {
pl, ok := l.servicePeer[svc]
if !ok {
return &l.servicePeerDefault
}
return &pl
}
func (l *fixedLimiter) GetProtocolLimits(proto protocol.ID) Limit {
pl, ok := l.protocol[proto]
if !ok {
return &l.protocolDefault
}
return &pl
}
func (l *fixedLimiter) GetProtocolPeerLimits(proto protocol.ID) Limit {
pl, ok := l.protocolPeer[proto]
if !ok {
return &l.protocolPeerDefault
}
return &pl
}
func (l *fixedLimiter) GetPeerLimits(p peer.ID) Limit {
pl, ok := l.peer[p]
if !ok {
return &l.peerDefault
}
return &pl
}
func (l *fixedLimiter) GetStreamLimits(_ peer.ID) Limit {
return &l.stream
}
func (l *fixedLimiter) GetConnLimits() Limit {
return &l.conn
}

View File

@@ -0,0 +1,45 @@
{
"System": {
"Memory": 65536,
"Conns": 16,
"ConnsInbound": 8,
"ConnsOutbound": 16,
"FD": 16
},
"ServiceDefault": {
"Memory": 8765
},
"Service": {
"A": {
"Memory": 8192
},
"B": {}
},
"ServicePeerDefault": {
"Memory": 2048
},
"ServicePeer": {
"A": {
"Memory": 4096
}
},
"ProtocolDefault": {
"Memory": 2048
},
"ProtocolPeerDefault": {
"Memory": 1024
},
"Protocol": {
"/A": {
"Memory": 8192
}
},
"PeerDefault": {
"Memory": 4096
},
"Peer": {
"12D3KooWPFH2Bx2tPfw6RLxN8k2wh47GRXgkt9yrAHU37zFwHWzS": {
"Memory": 4097
}
}
}

View File

@@ -0,0 +1,45 @@
{
"System": {
"Memory": 65536,
"Conns": 16,
"ConnsInbound": 8,
"ConnsOutbound": 16,
"FD": 16
},
"ServiceDefault": {
"Memory": 8765
},
"Service": {
"A": {
"Memory": 8192
},
"B": {}
},
"ServicePeerDefault": {
"Memory": 2048
},
"ServicePeer": {
"A": {
"Memory": 4096
}
},
"ProtocolDefault": {
"Memory": 2048
},
"ProtocolPeerDefault": {
"Memory": 1024
},
"Protocol": {
"/A": {
"Memory": 8192
}
},
"PeerDefault": {
"Memory": 4096
},
"Peer": {
"12D3KooWPFH2Bx2tPfw6RLxN8k2wh47GRXgkt9yrAHU37zFwHWzS": {
"Memory": 4097
}
}
}

View File

@@ -0,0 +1,112 @@
{
"System": {
"Streams": 18432,
"StreamsInbound": 9216,
"StreamsOutbound": 18432,
"Conns": 1152,
"ConnsInbound": 576,
"ConnsOutbound": 1152,
"FD": 16384,
"Memory": "8724152320"
},
"Transient": {
"Streams": 2304,
"StreamsInbound": 1152,
"StreamsOutbound": 2304,
"Conns": 320,
"ConnsInbound": 160,
"ConnsOutbound": 320,
"FD": 4096,
"Memory": "1107296256"
},
"AllowlistedSystem": {
"Streams": 18432,
"StreamsInbound": 9216,
"StreamsOutbound": 18432,
"Conns": 1152,
"ConnsInbound": 576,
"ConnsOutbound": 1152,
"FD": 16384,
"Memory": "8724152320"
},
"AllowlistedTransient": {
"Streams": 2304,
"StreamsInbound": 1152,
"StreamsOutbound": 2304,
"Conns": 320,
"ConnsInbound": 160,
"ConnsOutbound": 320,
"FD": 4096,
"Memory": "1107296256"
},
"ServiceDefault": {
"Streams": 20480,
"StreamsInbound": 5120,
"StreamsOutbound": 20480,
"Conns": "blockAll",
"ConnsInbound": "blockAll",
"ConnsOutbound": "blockAll",
"FD": "blockAll",
"Memory": "1140850688"
},
"ServicePeerDefault": {
"Streams": 320,
"StreamsInbound": 160,
"StreamsOutbound": 320,
"Conns": "blockAll",
"ConnsInbound": "blockAll",
"ConnsOutbound": "blockAll",
"FD": "blockAll",
"Memory": "50331648"
},
"ProtocolDefault": {
"Streams": 6144,
"StreamsInbound": 2560,
"StreamsOutbound": 6144,
"Conns": "blockAll",
"ConnsInbound": "blockAll",
"ConnsOutbound": "blockAll",
"FD": "blockAll",
"Memory": "1442840576"
},
"ProtocolPeerDefault": {
"Streams": 384,
"StreamsInbound": 96,
"StreamsOutbound": 192,
"Conns": "blockAll",
"ConnsInbound": "blockAll",
"ConnsOutbound": "blockAll",
"FD": "blockAll",
"Memory": "16777248"
},
"PeerDefault": {
"Streams": 2560,
"StreamsInbound": 1280,
"StreamsOutbound": 2560,
"Conns": 8,
"ConnsInbound": 8,
"ConnsOutbound": 8,
"FD": 256,
"Memory": "1140850688"
},
"Conn": {
"Streams": "blockAll",
"StreamsInbound": "blockAll",
"StreamsOutbound": "blockAll",
"Conns": 1,
"ConnsInbound": 1,
"ConnsOutbound": 1,
"FD": 1,
"Memory": "33554432"
},
"Stream": {
"Streams": 1,
"StreamsInbound": 1,
"StreamsOutbound": 1,
"Conns": "blockAll",
"ConnsInbound": "blockAll",
"ConnsOutbound": "blockAll",
"FD": "blockAll",
"Memory": "16777216"
}
}

View File

@@ -0,0 +1,879 @@
package rcmgr
import (
"encoding/json"
"fmt"
"math"
"strconv"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/pbnjay/memory"
)
type baseLimitConfig struct {
BaseLimit BaseLimit
BaseLimitIncrease BaseLimitIncrease
}
// ScalingLimitConfig is a struct for configuring default limits.
// {}BaseLimit is the limits that Apply for a minimal node (128 MB of memory for libp2p) and 256 file descriptors.
// {}LimitIncrease is the additional limit granted for every additional 1 GB of RAM.
type ScalingLimitConfig struct {
SystemBaseLimit BaseLimit
SystemLimitIncrease BaseLimitIncrease
TransientBaseLimit BaseLimit
TransientLimitIncrease BaseLimitIncrease
AllowlistedSystemBaseLimit BaseLimit
AllowlistedSystemLimitIncrease BaseLimitIncrease
AllowlistedTransientBaseLimit BaseLimit
AllowlistedTransientLimitIncrease BaseLimitIncrease
ServiceBaseLimit BaseLimit
ServiceLimitIncrease BaseLimitIncrease
ServiceLimits map[string]baseLimitConfig // use AddServiceLimit to modify
ServicePeerBaseLimit BaseLimit
ServicePeerLimitIncrease BaseLimitIncrease
ServicePeerLimits map[string]baseLimitConfig // use AddServicePeerLimit to modify
ProtocolBaseLimit BaseLimit
ProtocolLimitIncrease BaseLimitIncrease
ProtocolLimits map[protocol.ID]baseLimitConfig // use AddProtocolLimit to modify
ProtocolPeerBaseLimit BaseLimit
ProtocolPeerLimitIncrease BaseLimitIncrease
ProtocolPeerLimits map[protocol.ID]baseLimitConfig // use AddProtocolPeerLimit to modify
PeerBaseLimit BaseLimit
PeerLimitIncrease BaseLimitIncrease
PeerLimits map[peer.ID]baseLimitConfig // use AddPeerLimit to modify
ConnBaseLimit BaseLimit
ConnLimitIncrease BaseLimitIncrease
StreamBaseLimit BaseLimit
StreamLimitIncrease BaseLimitIncrease
}
func (cfg *ScalingLimitConfig) AddServiceLimit(svc string, base BaseLimit, inc BaseLimitIncrease) {
if cfg.ServiceLimits == nil {
cfg.ServiceLimits = make(map[string]baseLimitConfig)
}
cfg.ServiceLimits[svc] = baseLimitConfig{
BaseLimit: base,
BaseLimitIncrease: inc,
}
}
func (cfg *ScalingLimitConfig) AddProtocolLimit(proto protocol.ID, base BaseLimit, inc BaseLimitIncrease) {
if cfg.ProtocolLimits == nil {
cfg.ProtocolLimits = make(map[protocol.ID]baseLimitConfig)
}
cfg.ProtocolLimits[proto] = baseLimitConfig{
BaseLimit: base,
BaseLimitIncrease: inc,
}
}
func (cfg *ScalingLimitConfig) AddPeerLimit(p peer.ID, base BaseLimit, inc BaseLimitIncrease) {
if cfg.PeerLimits == nil {
cfg.PeerLimits = make(map[peer.ID]baseLimitConfig)
}
cfg.PeerLimits[p] = baseLimitConfig{
BaseLimit: base,
BaseLimitIncrease: inc,
}
}
func (cfg *ScalingLimitConfig) AddServicePeerLimit(svc string, base BaseLimit, inc BaseLimitIncrease) {
if cfg.ServicePeerLimits == nil {
cfg.ServicePeerLimits = make(map[string]baseLimitConfig)
}
cfg.ServicePeerLimits[svc] = baseLimitConfig{
BaseLimit: base,
BaseLimitIncrease: inc,
}
}
func (cfg *ScalingLimitConfig) AddProtocolPeerLimit(proto protocol.ID, base BaseLimit, inc BaseLimitIncrease) {
if cfg.ProtocolPeerLimits == nil {
cfg.ProtocolPeerLimits = make(map[protocol.ID]baseLimitConfig)
}
cfg.ProtocolPeerLimits[proto] = baseLimitConfig{
BaseLimit: base,
BaseLimitIncrease: inc,
}
}
type LimitVal int
const (
// DefaultLimit is the default value for resources. The exact value depends on the context, but will get values from `DefaultLimits`.
DefaultLimit LimitVal = 0
// Unlimited is the value for unlimited resources. An arbitrarily high number will also work.
Unlimited LimitVal = -1
// BlockAllLimit is the LimitVal for allowing no amount of resources.
BlockAllLimit LimitVal = -2
)
func (l LimitVal) MarshalJSON() ([]byte, error) {
if l == Unlimited {
return json.Marshal("unlimited")
} else if l == DefaultLimit {
return json.Marshal("default")
} else if l == BlockAllLimit {
return json.Marshal("blockAll")
}
return json.Marshal(int(l))
}
func (l *LimitVal) UnmarshalJSON(b []byte) error {
if string(b) == `"default"` {
*l = DefaultLimit
return nil
} else if string(b) == `"unlimited"` {
*l = Unlimited
return nil
} else if string(b) == `"blockAll"` {
*l = BlockAllLimit
return nil
}
var val int
if err := json.Unmarshal(b, &val); err != nil {
return err
}
if val == 0 {
// If there is an explicit 0 in the JSON we should interpret this as block all.
*l = BlockAllLimit
return nil
}
*l = LimitVal(val)
return nil
}
func (l LimitVal) Build(defaultVal int) int {
if l == DefaultLimit {
return defaultVal
}
if l == Unlimited {
return math.MaxInt
}
if l == BlockAllLimit {
return 0
}
return int(l)
}
type LimitVal64 int64
const (
// Default is the default value for resources.
DefaultLimit64 LimitVal64 = 0
// Unlimited is the value for unlimited resources.
Unlimited64 LimitVal64 = -1
// BlockAllLimit64 is the LimitVal for allowing no amount of resources.
BlockAllLimit64 LimitVal64 = -2
)
func (l LimitVal64) MarshalJSON() ([]byte, error) {
if l == Unlimited64 {
return json.Marshal("unlimited")
} else if l == DefaultLimit64 {
return json.Marshal("default")
} else if l == BlockAllLimit64 {
return json.Marshal("blockAll")
}
// Convert this to a string because JSON doesn't support 64-bit integers.
return json.Marshal(strconv.FormatInt(int64(l), 10))
}
func (l *LimitVal64) UnmarshalJSON(b []byte) error {
if string(b) == `"default"` {
*l = DefaultLimit64
return nil
} else if string(b) == `"unlimited"` {
*l = Unlimited64
return nil
} else if string(b) == `"blockAll"` {
*l = BlockAllLimit64
return nil
}
var val string
if err := json.Unmarshal(b, &val); err != nil {
// Is this an integer? Possible because of backwards compatibility.
var val int
if err := json.Unmarshal(b, &val); err != nil {
return fmt.Errorf("failed to unmarshal limit value: %w", err)
}
if val == 0 {
// If there is an explicit 0 in the JSON we should interpret this as block all.
*l = BlockAllLimit64
return nil
}
*l = LimitVal64(val)
return nil
}
i, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return err
}
if i == 0 {
// If there is an explicit 0 in the JSON we should interpret this as block all.
*l = BlockAllLimit64
return nil
}
*l = LimitVal64(i)
return nil
}
func (l LimitVal64) Build(defaultVal int64) int64 {
if l == DefaultLimit64 {
return defaultVal
}
if l == Unlimited64 {
return math.MaxInt64
}
if l == BlockAllLimit64 {
return 0
}
return int64(l)
}
// ResourceLimits is the type for basic resource limits.
type ResourceLimits struct {
Streams LimitVal `json:",omitempty"`
StreamsInbound LimitVal `json:",omitempty"`
StreamsOutbound LimitVal `json:",omitempty"`
Conns LimitVal `json:",omitempty"`
ConnsInbound LimitVal `json:",omitempty"`
ConnsOutbound LimitVal `json:",omitempty"`
FD LimitVal `json:",omitempty"`
Memory LimitVal64 `json:",omitempty"`
}
func (l *ResourceLimits) IsDefault() bool {
if l == nil {
return true
}
if l.Streams == DefaultLimit &&
l.StreamsInbound == DefaultLimit &&
l.StreamsOutbound == DefaultLimit &&
l.Conns == DefaultLimit &&
l.ConnsInbound == DefaultLimit &&
l.ConnsOutbound == DefaultLimit &&
l.FD == DefaultLimit &&
l.Memory == DefaultLimit64 {
return true
}
return false
}
func (l *ResourceLimits) ToMaybeNilPtr() *ResourceLimits {
if l.IsDefault() {
return nil
}
return l
}
// Apply overwrites all default limits with the values of l2
func (l *ResourceLimits) Apply(l2 ResourceLimits) {
if l.Streams == DefaultLimit {
l.Streams = l2.Streams
}
if l.StreamsInbound == DefaultLimit {
l.StreamsInbound = l2.StreamsInbound
}
if l.StreamsOutbound == DefaultLimit {
l.StreamsOutbound = l2.StreamsOutbound
}
if l.Conns == DefaultLimit {
l.Conns = l2.Conns
}
if l.ConnsInbound == DefaultLimit {
l.ConnsInbound = l2.ConnsInbound
}
if l.ConnsOutbound == DefaultLimit {
l.ConnsOutbound = l2.ConnsOutbound
}
if l.FD == DefaultLimit {
l.FD = l2.FD
}
if l.Memory == DefaultLimit64 {
l.Memory = l2.Memory
}
}
func (l *ResourceLimits) Build(defaults Limit) BaseLimit {
if l == nil {
return BaseLimit{
Streams: defaults.GetStreamTotalLimit(),
StreamsInbound: defaults.GetStreamLimit(network.DirInbound),
StreamsOutbound: defaults.GetStreamLimit(network.DirOutbound),
Conns: defaults.GetConnTotalLimit(),
ConnsInbound: defaults.GetConnLimit(network.DirInbound),
ConnsOutbound: defaults.GetConnLimit(network.DirOutbound),
FD: defaults.GetFDLimit(),
Memory: defaults.GetMemoryLimit(),
}
}
return BaseLimit{
Streams: l.Streams.Build(defaults.GetStreamTotalLimit()),
StreamsInbound: l.StreamsInbound.Build(defaults.GetStreamLimit(network.DirInbound)),
StreamsOutbound: l.StreamsOutbound.Build(defaults.GetStreamLimit(network.DirOutbound)),
Conns: l.Conns.Build(defaults.GetConnTotalLimit()),
ConnsInbound: l.ConnsInbound.Build(defaults.GetConnLimit(network.DirInbound)),
ConnsOutbound: l.ConnsOutbound.Build(defaults.GetConnLimit(network.DirOutbound)),
FD: l.FD.Build(defaults.GetFDLimit()),
Memory: l.Memory.Build(defaults.GetMemoryLimit()),
}
}
type PartialLimitConfig struct {
System ResourceLimits `json:",omitempty"`
Transient ResourceLimits `json:",omitempty"`
// Limits that are applied to resources with an allowlisted multiaddr.
// These will only be used if the normal System & Transient limits are
// reached.
AllowlistedSystem ResourceLimits `json:",omitempty"`
AllowlistedTransient ResourceLimits `json:",omitempty"`
ServiceDefault ResourceLimits `json:",omitempty"`
Service map[string]ResourceLimits `json:",omitempty"`
ServicePeerDefault ResourceLimits `json:",omitempty"`
ServicePeer map[string]ResourceLimits `json:",omitempty"`
ProtocolDefault ResourceLimits `json:",omitempty"`
Protocol map[protocol.ID]ResourceLimits `json:",omitempty"`
ProtocolPeerDefault ResourceLimits `json:",omitempty"`
ProtocolPeer map[protocol.ID]ResourceLimits `json:",omitempty"`
PeerDefault ResourceLimits `json:",omitempty"`
Peer map[peer.ID]ResourceLimits `json:",omitempty"`
Conn ResourceLimits `json:",omitempty"`
Stream ResourceLimits `json:",omitempty"`
}
func (cfg *PartialLimitConfig) MarshalJSON() ([]byte, error) {
// we want to marshal the encoded peer id
encodedPeerMap := make(map[string]ResourceLimits, len(cfg.Peer))
for p, v := range cfg.Peer {
encodedPeerMap[p.String()] = v
}
type Alias PartialLimitConfig
return json.Marshal(&struct {
*Alias
// String so we can have the properly marshalled peer id
Peer map[string]ResourceLimits `json:",omitempty"`
// The rest of the fields as pointers so that we omit empty values in the serialized result
System *ResourceLimits `json:",omitempty"`
Transient *ResourceLimits `json:",omitempty"`
AllowlistedSystem *ResourceLimits `json:",omitempty"`
AllowlistedTransient *ResourceLimits `json:",omitempty"`
ServiceDefault *ResourceLimits `json:",omitempty"`
ServicePeerDefault *ResourceLimits `json:",omitempty"`
ProtocolDefault *ResourceLimits `json:",omitempty"`
ProtocolPeerDefault *ResourceLimits `json:",omitempty"`
PeerDefault *ResourceLimits `json:",omitempty"`
Conn *ResourceLimits `json:",omitempty"`
Stream *ResourceLimits `json:",omitempty"`
}{
Alias: (*Alias)(cfg),
Peer: encodedPeerMap,
System: cfg.System.ToMaybeNilPtr(),
Transient: cfg.Transient.ToMaybeNilPtr(),
AllowlistedSystem: cfg.AllowlistedSystem.ToMaybeNilPtr(),
AllowlistedTransient: cfg.AllowlistedTransient.ToMaybeNilPtr(),
ServiceDefault: cfg.ServiceDefault.ToMaybeNilPtr(),
ServicePeerDefault: cfg.ServicePeerDefault.ToMaybeNilPtr(),
ProtocolDefault: cfg.ProtocolDefault.ToMaybeNilPtr(),
ProtocolPeerDefault: cfg.ProtocolPeerDefault.ToMaybeNilPtr(),
PeerDefault: cfg.PeerDefault.ToMaybeNilPtr(),
Conn: cfg.Conn.ToMaybeNilPtr(),
Stream: cfg.Stream.ToMaybeNilPtr(),
})
}
func applyResourceLimitsMap[K comparable](this *map[K]ResourceLimits, other map[K]ResourceLimits, fallbackDefault ResourceLimits) {
for k, l := range *this {
r := fallbackDefault
if l2, ok := other[k]; ok {
r = l2
}
l.Apply(r)
(*this)[k] = l
}
if *this == nil && other != nil {
*this = make(map[K]ResourceLimits)
}
for k, l := range other {
if _, ok := (*this)[k]; !ok {
(*this)[k] = l
}
}
}
func (cfg *PartialLimitConfig) Apply(c PartialLimitConfig) {
cfg.System.Apply(c.System)
cfg.Transient.Apply(c.Transient)
cfg.AllowlistedSystem.Apply(c.AllowlistedSystem)
cfg.AllowlistedTransient.Apply(c.AllowlistedTransient)
cfg.ServiceDefault.Apply(c.ServiceDefault)
cfg.ServicePeerDefault.Apply(c.ServicePeerDefault)
cfg.ProtocolDefault.Apply(c.ProtocolDefault)
cfg.ProtocolPeerDefault.Apply(c.ProtocolPeerDefault)
cfg.PeerDefault.Apply(c.PeerDefault)
cfg.Conn.Apply(c.Conn)
cfg.Stream.Apply(c.Stream)
applyResourceLimitsMap(&cfg.Service, c.Service, cfg.ServiceDefault)
applyResourceLimitsMap(&cfg.ServicePeer, c.ServicePeer, cfg.ServicePeerDefault)
applyResourceLimitsMap(&cfg.Protocol, c.Protocol, cfg.ProtocolDefault)
applyResourceLimitsMap(&cfg.ProtocolPeer, c.ProtocolPeer, cfg.ProtocolPeerDefault)
applyResourceLimitsMap(&cfg.Peer, c.Peer, cfg.PeerDefault)
}
func (cfg PartialLimitConfig) Build(defaults ConcreteLimitConfig) ConcreteLimitConfig {
out := defaults
out.system = cfg.System.Build(defaults.system)
out.transient = cfg.Transient.Build(defaults.transient)
out.allowlistedSystem = cfg.AllowlistedSystem.Build(defaults.allowlistedSystem)
out.allowlistedTransient = cfg.AllowlistedTransient.Build(defaults.allowlistedTransient)
out.serviceDefault = cfg.ServiceDefault.Build(defaults.serviceDefault)
out.servicePeerDefault = cfg.ServicePeerDefault.Build(defaults.servicePeerDefault)
out.protocolDefault = cfg.ProtocolDefault.Build(defaults.protocolDefault)
out.protocolPeerDefault = cfg.ProtocolPeerDefault.Build(defaults.protocolPeerDefault)
out.peerDefault = cfg.PeerDefault.Build(defaults.peerDefault)
out.conn = cfg.Conn.Build(defaults.conn)
out.stream = cfg.Stream.Build(defaults.stream)
out.service = buildMapWithDefault(cfg.Service, defaults.service, out.serviceDefault)
out.servicePeer = buildMapWithDefault(cfg.ServicePeer, defaults.servicePeer, out.servicePeerDefault)
out.protocol = buildMapWithDefault(cfg.Protocol, defaults.protocol, out.protocolDefault)
out.protocolPeer = buildMapWithDefault(cfg.ProtocolPeer, defaults.protocolPeer, out.protocolPeerDefault)
out.peer = buildMapWithDefault(cfg.Peer, defaults.peer, out.peerDefault)
return out
}
func buildMapWithDefault[K comparable](definedLimits map[K]ResourceLimits, defaults map[K]BaseLimit, fallbackDefault BaseLimit) map[K]BaseLimit {
if definedLimits == nil && defaults == nil {
return nil
}
out := make(map[K]BaseLimit)
for k, l := range defaults {
out[k] = l
}
for k, l := range definedLimits {
if defaultForKey, ok := out[k]; ok {
out[k] = l.Build(defaultForKey)
} else {
out[k] = l.Build(fallbackDefault)
}
}
return out
}
// ConcreteLimitConfig is similar to PartialLimitConfig, but all values are defined.
// There is no unset "default" value. Commonly constructed by calling
// PartialLimitConfig.Build(rcmgr.DefaultLimits.AutoScale())
type ConcreteLimitConfig struct {
system BaseLimit
transient BaseLimit
// Limits that are applied to resources with an allowlisted multiaddr.
// These will only be used if the normal System & Transient limits are
// reached.
allowlistedSystem BaseLimit
allowlistedTransient BaseLimit
serviceDefault BaseLimit
service map[string]BaseLimit
servicePeerDefault BaseLimit
servicePeer map[string]BaseLimit
protocolDefault BaseLimit
protocol map[protocol.ID]BaseLimit
protocolPeerDefault BaseLimit
protocolPeer map[protocol.ID]BaseLimit
peerDefault BaseLimit
peer map[peer.ID]BaseLimit
conn BaseLimit
stream BaseLimit
}
func resourceLimitsMapFromBaseLimitMap[K comparable](baseLimitMap map[K]BaseLimit) map[K]ResourceLimits {
if baseLimitMap == nil {
return nil
}
out := make(map[K]ResourceLimits)
for k, l := range baseLimitMap {
out[k] = l.ToResourceLimits()
}
return out
}
// ToPartialLimitConfig converts a ConcreteLimitConfig to a PartialLimitConfig.
// The returned PartialLimitConfig will have no default values.
func (cfg ConcreteLimitConfig) ToPartialLimitConfig() PartialLimitConfig {
return PartialLimitConfig{
System: cfg.system.ToResourceLimits(),
Transient: cfg.transient.ToResourceLimits(),
AllowlistedSystem: cfg.allowlistedSystem.ToResourceLimits(),
AllowlistedTransient: cfg.allowlistedTransient.ToResourceLimits(),
ServiceDefault: cfg.serviceDefault.ToResourceLimits(),
Service: resourceLimitsMapFromBaseLimitMap(cfg.service),
ServicePeerDefault: cfg.servicePeerDefault.ToResourceLimits(),
ServicePeer: resourceLimitsMapFromBaseLimitMap(cfg.servicePeer),
ProtocolDefault: cfg.protocolDefault.ToResourceLimits(),
Protocol: resourceLimitsMapFromBaseLimitMap(cfg.protocol),
ProtocolPeerDefault: cfg.protocolPeerDefault.ToResourceLimits(),
ProtocolPeer: resourceLimitsMapFromBaseLimitMap(cfg.protocolPeer),
PeerDefault: cfg.peerDefault.ToResourceLimits(),
Peer: resourceLimitsMapFromBaseLimitMap(cfg.peer),
Conn: cfg.conn.ToResourceLimits(),
Stream: cfg.stream.ToResourceLimits(),
}
}
// Scale scales up a limit configuration.
// memory is the amount of memory that the stack is allowed to consume,
// for a dedicated node it's recommended to use 1/8 of the installed system memory.
// If memory is smaller than 128 MB, the base configuration will be used.
func (cfg *ScalingLimitConfig) Scale(memory int64, numFD int) ConcreteLimitConfig {
lc := ConcreteLimitConfig{
system: scale(cfg.SystemBaseLimit, cfg.SystemLimitIncrease, memory, numFD),
transient: scale(cfg.TransientBaseLimit, cfg.TransientLimitIncrease, memory, numFD),
allowlistedSystem: scale(cfg.AllowlistedSystemBaseLimit, cfg.AllowlistedSystemLimitIncrease, memory, numFD),
allowlistedTransient: scale(cfg.AllowlistedTransientBaseLimit, cfg.AllowlistedTransientLimitIncrease, memory, numFD),
serviceDefault: scale(cfg.ServiceBaseLimit, cfg.ServiceLimitIncrease, memory, numFD),
servicePeerDefault: scale(cfg.ServicePeerBaseLimit, cfg.ServicePeerLimitIncrease, memory, numFD),
protocolDefault: scale(cfg.ProtocolBaseLimit, cfg.ProtocolLimitIncrease, memory, numFD),
protocolPeerDefault: scale(cfg.ProtocolPeerBaseLimit, cfg.ProtocolPeerLimitIncrease, memory, numFD),
peerDefault: scale(cfg.PeerBaseLimit, cfg.PeerLimitIncrease, memory, numFD),
conn: scale(cfg.ConnBaseLimit, cfg.ConnLimitIncrease, memory, numFD),
stream: scale(cfg.StreamBaseLimit, cfg.ConnLimitIncrease, memory, numFD),
}
if cfg.ServiceLimits != nil {
lc.service = make(map[string]BaseLimit)
for svc, l := range cfg.ServiceLimits {
lc.service[svc] = scale(l.BaseLimit, l.BaseLimitIncrease, memory, numFD)
}
}
if cfg.ProtocolLimits != nil {
lc.protocol = make(map[protocol.ID]BaseLimit)
for proto, l := range cfg.ProtocolLimits {
lc.protocol[proto] = scale(l.BaseLimit, l.BaseLimitIncrease, memory, numFD)
}
}
if cfg.PeerLimits != nil {
lc.peer = make(map[peer.ID]BaseLimit)
for p, l := range cfg.PeerLimits {
lc.peer[p] = scale(l.BaseLimit, l.BaseLimitIncrease, memory, numFD)
}
}
if cfg.ServicePeerLimits != nil {
lc.servicePeer = make(map[string]BaseLimit)
for svc, l := range cfg.ServicePeerLimits {
lc.servicePeer[svc] = scale(l.BaseLimit, l.BaseLimitIncrease, memory, numFD)
}
}
if cfg.ProtocolPeerLimits != nil {
lc.protocolPeer = make(map[protocol.ID]BaseLimit)
for p, l := range cfg.ProtocolPeerLimits {
lc.protocolPeer[p] = scale(l.BaseLimit, l.BaseLimitIncrease, memory, numFD)
}
}
return lc
}
func (cfg *ScalingLimitConfig) AutoScale() ConcreteLimitConfig {
return cfg.Scale(
int64(memory.TotalMemory())/8,
getNumFDs()/2,
)
}
func scale(base BaseLimit, inc BaseLimitIncrease, memory int64, numFD int) BaseLimit {
// mebibytesAvailable represents how many MiBs we're allowed to use. Used to
// scale the limits. If this is below 128MiB we set it to 0 to just use the
// base amounts.
var mebibytesAvailable int
if memory > 128<<20 {
mebibytesAvailable = int((memory) >> 20)
}
l := BaseLimit{
StreamsInbound: base.StreamsInbound + (inc.StreamsInbound*mebibytesAvailable)>>10,
StreamsOutbound: base.StreamsOutbound + (inc.StreamsOutbound*mebibytesAvailable)>>10,
Streams: base.Streams + (inc.Streams*mebibytesAvailable)>>10,
ConnsInbound: base.ConnsInbound + (inc.ConnsInbound*mebibytesAvailable)>>10,
ConnsOutbound: base.ConnsOutbound + (inc.ConnsOutbound*mebibytesAvailable)>>10,
Conns: base.Conns + (inc.Conns*mebibytesAvailable)>>10,
Memory: base.Memory + (inc.Memory*int64(mebibytesAvailable))>>10,
FD: base.FD,
}
if inc.FDFraction > 0 && numFD > 0 {
l.FD = int(inc.FDFraction * float64(numFD))
if l.FD < base.FD {
// Use at least the base amount
l.FD = base.FD
}
}
return l
}
// DefaultLimits are the limits used by the default limiter constructors.
var DefaultLimits = ScalingLimitConfig{
SystemBaseLimit: BaseLimit{
ConnsInbound: 64,
ConnsOutbound: 128,
Conns: 128,
StreamsInbound: 64 * 16,
StreamsOutbound: 128 * 16,
Streams: 128 * 16,
Memory: 128 << 20,
FD: 256,
},
SystemLimitIncrease: BaseLimitIncrease{
ConnsInbound: 64,
ConnsOutbound: 128,
Conns: 128,
StreamsInbound: 64 * 16,
StreamsOutbound: 128 * 16,
Streams: 128 * 16,
Memory: 1 << 30,
FDFraction: 1,
},
TransientBaseLimit: BaseLimit{
ConnsInbound: 32,
ConnsOutbound: 64,
Conns: 64,
StreamsInbound: 128,
StreamsOutbound: 256,
Streams: 256,
Memory: 32 << 20,
FD: 64,
},
TransientLimitIncrease: BaseLimitIncrease{
ConnsInbound: 16,
ConnsOutbound: 32,
Conns: 32,
StreamsInbound: 128,
StreamsOutbound: 256,
Streams: 256,
Memory: 128 << 20,
FDFraction: 0.25,
},
// Setting the allowlisted limits to be the same as the normal limits. The
// allowlist only activates when you reach your normal system/transient
// limits. So it's okay if these limits err on the side of being too big,
// since most of the time you won't even use any of these. Tune these down
// if you want to manage your resources against an allowlisted endpoint.
AllowlistedSystemBaseLimit: BaseLimit{
ConnsInbound: 64,
ConnsOutbound: 128,
Conns: 128,
StreamsInbound: 64 * 16,
StreamsOutbound: 128 * 16,
Streams: 128 * 16,
Memory: 128 << 20,
FD: 256,
},
AllowlistedSystemLimitIncrease: BaseLimitIncrease{
ConnsInbound: 64,
ConnsOutbound: 128,
Conns: 128,
StreamsInbound: 64 * 16,
StreamsOutbound: 128 * 16,
Streams: 128 * 16,
Memory: 1 << 30,
FDFraction: 1,
},
AllowlistedTransientBaseLimit: BaseLimit{
ConnsInbound: 32,
ConnsOutbound: 64,
Conns: 64,
StreamsInbound: 128,
StreamsOutbound: 256,
Streams: 256,
Memory: 32 << 20,
FD: 64,
},
AllowlistedTransientLimitIncrease: BaseLimitIncrease{
ConnsInbound: 16,
ConnsOutbound: 32,
Conns: 32,
StreamsInbound: 128,
StreamsOutbound: 256,
Streams: 256,
Memory: 128 << 20,
FDFraction: 0.25,
},
ServiceBaseLimit: BaseLimit{
StreamsInbound: 1024,
StreamsOutbound: 4096,
Streams: 4096,
Memory: 64 << 20,
},
ServiceLimitIncrease: BaseLimitIncrease{
StreamsInbound: 512,
StreamsOutbound: 2048,
Streams: 2048,
Memory: 128 << 20,
},
ServicePeerBaseLimit: BaseLimit{
StreamsInbound: 128,
StreamsOutbound: 256,
Streams: 256,
Memory: 16 << 20,
},
ServicePeerLimitIncrease: BaseLimitIncrease{
StreamsInbound: 4,
StreamsOutbound: 8,
Streams: 8,
Memory: 4 << 20,
},
ProtocolBaseLimit: BaseLimit{
StreamsInbound: 512,
StreamsOutbound: 2048,
Streams: 2048,
Memory: 64 << 20,
},
ProtocolLimitIncrease: BaseLimitIncrease{
StreamsInbound: 256,
StreamsOutbound: 512,
Streams: 512,
Memory: 164 << 20,
},
ProtocolPeerBaseLimit: BaseLimit{
StreamsInbound: 64,
StreamsOutbound: 128,
Streams: 256,
Memory: 16 << 20,
},
ProtocolPeerLimitIncrease: BaseLimitIncrease{
StreamsInbound: 4,
StreamsOutbound: 8,
Streams: 16,
Memory: 4,
},
PeerBaseLimit: BaseLimit{
// 8 for now so that it matches the number of concurrent dials we may do
// in swarm_dial.go. With future smart dialing work we should bring this
// down
ConnsInbound: 8,
ConnsOutbound: 8,
Conns: 8,
StreamsInbound: 256,
StreamsOutbound: 512,
Streams: 512,
Memory: 64 << 20,
FD: 4,
},
PeerLimitIncrease: BaseLimitIncrease{
StreamsInbound: 128,
StreamsOutbound: 256,
Streams: 256,
Memory: 128 << 20,
FDFraction: 1.0 / 64,
},
ConnBaseLimit: BaseLimit{
ConnsInbound: 1,
ConnsOutbound: 1,
Conns: 1,
FD: 1,
Memory: 32 << 20,
},
StreamBaseLimit: BaseLimit{
StreamsInbound: 1,
StreamsOutbound: 1,
Streams: 1,
Memory: 16 << 20,
},
}
var infiniteBaseLimit = BaseLimit{
Streams: math.MaxInt,
StreamsInbound: math.MaxInt,
StreamsOutbound: math.MaxInt,
Conns: math.MaxInt,
ConnsInbound: math.MaxInt,
ConnsOutbound: math.MaxInt,
FD: math.MaxInt,
Memory: math.MaxInt64,
}
// InfiniteLimits are a limiter configuration that uses unlimited limits, thus effectively not limiting anything.
// Keep in mind that the operating system limits the number of file descriptors that an application can use.
var InfiniteLimits = ConcreteLimitConfig{
system: infiniteBaseLimit,
transient: infiniteBaseLimit,
allowlistedSystem: infiniteBaseLimit,
allowlistedTransient: infiniteBaseLimit,
serviceDefault: infiniteBaseLimit,
servicePeerDefault: infiniteBaseLimit,
protocolDefault: infiniteBaseLimit,
protocolPeerDefault: infiniteBaseLimit,
peerDefault: infiniteBaseLimit,
conn: infiniteBaseLimit,
stream: infiniteBaseLimit,
}

View File

@@ -0,0 +1,168 @@
package rcmgr
import (
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
)
// MetricsReporter is an interface for collecting metrics from resource manager actions
type MetricsReporter interface {
// AllowConn is invoked when opening a connection is allowed
AllowConn(dir network.Direction, usefd bool)
// BlockConn is invoked when opening a connection is blocked
BlockConn(dir network.Direction, usefd bool)
// AllowStream is invoked when opening a stream is allowed
AllowStream(p peer.ID, dir network.Direction)
// BlockStream is invoked when opening a stream is blocked
BlockStream(p peer.ID, dir network.Direction)
// AllowPeer is invoked when attaching ac onnection to a peer is allowed
AllowPeer(p peer.ID)
// BlockPeer is invoked when attaching ac onnection to a peer is blocked
BlockPeer(p peer.ID)
// AllowProtocol is invoked when setting the protocol for a stream is allowed
AllowProtocol(proto protocol.ID)
// BlockProtocol is invoked when setting the protocol for a stream is blocked
BlockProtocol(proto protocol.ID)
// BlockProtocolPeer is invoked when setting the protocol for a stream is blocked at the per protocol peer scope
BlockProtocolPeer(proto protocol.ID, p peer.ID)
// AllowService is invoked when setting the protocol for a stream is allowed
AllowService(svc string)
// BlockService is invoked when setting the protocol for a stream is blocked
BlockService(svc string)
// BlockServicePeer is invoked when setting the service for a stream is blocked at the per service peer scope
BlockServicePeer(svc string, p peer.ID)
// AllowMemory is invoked when a memory reservation is allowed
AllowMemory(size int)
// BlockMemory is invoked when a memory reservation is blocked
BlockMemory(size int)
}
type metrics struct {
reporter MetricsReporter
}
// WithMetrics is a resource manager option to enable metrics collection
func WithMetrics(reporter MetricsReporter) Option {
return func(r *resourceManager) error {
r.metrics = &metrics{reporter: reporter}
return nil
}
}
func (m *metrics) AllowConn(dir network.Direction, usefd bool) {
if m == nil {
return
}
m.reporter.AllowConn(dir, usefd)
}
func (m *metrics) BlockConn(dir network.Direction, usefd bool) {
if m == nil {
return
}
m.reporter.BlockConn(dir, usefd)
}
func (m *metrics) AllowStream(p peer.ID, dir network.Direction) {
if m == nil {
return
}
m.reporter.AllowStream(p, dir)
}
func (m *metrics) BlockStream(p peer.ID, dir network.Direction) {
if m == nil {
return
}
m.reporter.BlockStream(p, dir)
}
func (m *metrics) AllowPeer(p peer.ID) {
if m == nil {
return
}
m.reporter.AllowPeer(p)
}
func (m *metrics) BlockPeer(p peer.ID) {
if m == nil {
return
}
m.reporter.BlockPeer(p)
}
func (m *metrics) AllowProtocol(proto protocol.ID) {
if m == nil {
return
}
m.reporter.AllowProtocol(proto)
}
func (m *metrics) BlockProtocol(proto protocol.ID) {
if m == nil {
return
}
m.reporter.BlockProtocol(proto)
}
func (m *metrics) BlockProtocolPeer(proto protocol.ID, p peer.ID) {
if m == nil {
return
}
m.reporter.BlockProtocolPeer(proto, p)
}
func (m *metrics) AllowService(svc string) {
if m == nil {
return
}
m.reporter.AllowService(svc)
}
func (m *metrics) BlockService(svc string) {
if m == nil {
return
}
m.reporter.BlockService(svc)
}
func (m *metrics) BlockServicePeer(svc string, p peer.ID) {
if m == nil {
return
}
m.reporter.BlockServicePeer(svc, p)
}
func (m *metrics) AllowMemory(size int) {
if m == nil {
return
}
m.reporter.AllowMemory(size)
}
func (m *metrics) BlockMemory(size int) {
if m == nil {
return
}
m.reporter.BlockMemory(size)
}

View File

@@ -0,0 +1,878 @@
package rcmgr
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
logging "github.com/ipfs/go-log/v2"
"github.com/multiformats/go-multiaddr"
)
var log = logging.Logger("rcmgr")
type resourceManager struct {
limits Limiter
trace *trace
metrics *metrics
disableMetrics bool
allowlist *Allowlist
system *systemScope
transient *transientScope
allowlistedSystem *systemScope
allowlistedTransient *transientScope
cancelCtx context.Context
cancel func()
wg sync.WaitGroup
mx sync.Mutex
svc map[string]*serviceScope
proto map[protocol.ID]*protocolScope
peer map[peer.ID]*peerScope
stickyProto map[protocol.ID]struct{}
stickyPeer map[peer.ID]struct{}
connId, streamId int64
}
var _ network.ResourceManager = (*resourceManager)(nil)
type systemScope struct {
*resourceScope
}
var _ network.ResourceScope = (*systemScope)(nil)
type transientScope struct {
*resourceScope
system *systemScope
}
var _ network.ResourceScope = (*transientScope)(nil)
type serviceScope struct {
*resourceScope
service string
rcmgr *resourceManager
peers map[peer.ID]*resourceScope
}
var _ network.ServiceScope = (*serviceScope)(nil)
type protocolScope struct {
*resourceScope
proto protocol.ID
rcmgr *resourceManager
peers map[peer.ID]*resourceScope
}
var _ network.ProtocolScope = (*protocolScope)(nil)
type peerScope struct {
*resourceScope
peer peer.ID
rcmgr *resourceManager
}
var _ network.PeerScope = (*peerScope)(nil)
type connectionScope struct {
*resourceScope
dir network.Direction
usefd bool
isAllowlisted bool
rcmgr *resourceManager
peer *peerScope
endpoint multiaddr.Multiaddr
}
var _ network.ConnScope = (*connectionScope)(nil)
var _ network.ConnManagementScope = (*connectionScope)(nil)
type streamScope struct {
*resourceScope
dir network.Direction
rcmgr *resourceManager
peer *peerScope
svc *serviceScope
proto *protocolScope
peerProtoScope *resourceScope
peerSvcScope *resourceScope
}
var _ network.StreamScope = (*streamScope)(nil)
var _ network.StreamManagementScope = (*streamScope)(nil)
type Option func(*resourceManager) error
func NewResourceManager(limits Limiter, opts ...Option) (network.ResourceManager, error) {
allowlist := newAllowlist()
r := &resourceManager{
limits: limits,
allowlist: &allowlist,
svc: make(map[string]*serviceScope),
proto: make(map[protocol.ID]*protocolScope),
peer: make(map[peer.ID]*peerScope),
}
for _, opt := range opts {
if err := opt(r); err != nil {
return nil, err
}
}
if !r.disableMetrics {
var sr TraceReporter
sr, err := NewStatsTraceReporter()
if err != nil {
log.Errorf("failed to initialise StatsTraceReporter %s", err)
} else {
if r.trace == nil {
r.trace = &trace{}
}
found := false
for _, rep := range r.trace.reporters {
if rep == sr {
found = true
break
}
}
if !found {
r.trace.reporters = append(r.trace.reporters, sr)
}
}
}
if err := r.trace.Start(limits); err != nil {
return nil, err
}
r.system = newSystemScope(limits.GetSystemLimits(), r, "system")
r.system.IncRef()
r.transient = newTransientScope(limits.GetTransientLimits(), r, "transient", r.system.resourceScope)
r.transient.IncRef()
r.allowlistedSystem = newSystemScope(limits.GetAllowlistedSystemLimits(), r, "allowlistedSystem")
r.allowlistedSystem.IncRef()
r.allowlistedTransient = newTransientScope(limits.GetAllowlistedTransientLimits(), r, "allowlistedTransient", r.allowlistedSystem.resourceScope)
r.allowlistedTransient.IncRef()
r.cancelCtx, r.cancel = context.WithCancel(context.Background())
r.wg.Add(1)
go r.background()
return r, nil
}
func (r *resourceManager) GetAllowlist() *Allowlist {
return r.allowlist
}
// GetAllowlist tries to get the allowlist from the given resourcemanager
// interface by checking to see if its concrete type is a resourceManager.
// Returns nil if it fails to get the allowlist.
func GetAllowlist(rcmgr network.ResourceManager) *Allowlist {
r, ok := rcmgr.(*resourceManager)
if !ok {
return nil
}
return r.allowlist
}
func (r *resourceManager) ViewSystem(f func(network.ResourceScope) error) error {
return f(r.system)
}
func (r *resourceManager) ViewTransient(f func(network.ResourceScope) error) error {
return f(r.transient)
}
func (r *resourceManager) ViewService(srv string, f func(network.ServiceScope) error) error {
s := r.getServiceScope(srv)
defer s.DecRef()
return f(s)
}
func (r *resourceManager) ViewProtocol(proto protocol.ID, f func(network.ProtocolScope) error) error {
s := r.getProtocolScope(proto)
defer s.DecRef()
return f(s)
}
func (r *resourceManager) ViewPeer(p peer.ID, f func(network.PeerScope) error) error {
s := r.getPeerScope(p)
defer s.DecRef()
return f(s)
}
func (r *resourceManager) getServiceScope(svc string) *serviceScope {
r.mx.Lock()
defer r.mx.Unlock()
s, ok := r.svc[svc]
if !ok {
s = newServiceScope(svc, r.limits.GetServiceLimits(svc), r)
r.svc[svc] = s
}
s.IncRef()
return s
}
func (r *resourceManager) getProtocolScope(proto protocol.ID) *protocolScope {
r.mx.Lock()
defer r.mx.Unlock()
s, ok := r.proto[proto]
if !ok {
s = newProtocolScope(proto, r.limits.GetProtocolLimits(proto), r)
r.proto[proto] = s
}
s.IncRef()
return s
}
func (r *resourceManager) setStickyProtocol(proto protocol.ID) {
r.mx.Lock()
defer r.mx.Unlock()
if r.stickyProto == nil {
r.stickyProto = make(map[protocol.ID]struct{})
}
r.stickyProto[proto] = struct{}{}
}
func (r *resourceManager) getPeerScope(p peer.ID) *peerScope {
r.mx.Lock()
defer r.mx.Unlock()
s, ok := r.peer[p]
if !ok {
s = newPeerScope(p, r.limits.GetPeerLimits(p), r)
r.peer[p] = s
}
s.IncRef()
return s
}
func (r *resourceManager) setStickyPeer(p peer.ID) {
r.mx.Lock()
defer r.mx.Unlock()
if r.stickyPeer == nil {
r.stickyPeer = make(map[peer.ID]struct{})
}
r.stickyPeer[p] = struct{}{}
}
func (r *resourceManager) nextConnId() int64 {
r.mx.Lock()
defer r.mx.Unlock()
r.connId++
return r.connId
}
func (r *resourceManager) nextStreamId() int64 {
r.mx.Lock()
defer r.mx.Unlock()
r.streamId++
return r.streamId
}
func (r *resourceManager) OpenConnection(dir network.Direction, usefd bool, endpoint multiaddr.Multiaddr) (network.ConnManagementScope, error) {
var conn *connectionScope
conn = newConnectionScope(dir, usefd, r.limits.GetConnLimits(), r, endpoint)
err := conn.AddConn(dir, usefd)
if err != nil {
// Try again if this is an allowlisted connection
// Failed to open connection, let's see if this was allowlisted and try again
allowed := r.allowlist.Allowed(endpoint)
if allowed {
conn.Done()
conn = newAllowListedConnectionScope(dir, usefd, r.limits.GetConnLimits(), r, endpoint)
err = conn.AddConn(dir, usefd)
}
}
if err != nil {
conn.Done()
r.metrics.BlockConn(dir, usefd)
return nil, err
}
r.metrics.AllowConn(dir, usefd)
return conn, nil
}
func (r *resourceManager) OpenStream(p peer.ID, dir network.Direction) (network.StreamManagementScope, error) {
peer := r.getPeerScope(p)
stream := newStreamScope(dir, r.limits.GetStreamLimits(p), peer, r)
peer.DecRef() // we have the reference in edges
err := stream.AddStream(dir)
if err != nil {
stream.Done()
r.metrics.BlockStream(p, dir)
return nil, err
}
r.metrics.AllowStream(p, dir)
return stream, nil
}
func (r *resourceManager) Close() error {
r.cancel()
r.wg.Wait()
r.trace.Close()
return nil
}
func (r *resourceManager) background() {
defer r.wg.Done()
// periodically garbage collects unused peer and protocol scopes
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
r.gc()
case <-r.cancelCtx.Done():
return
}
}
}
func (r *resourceManager) gc() {
r.mx.Lock()
defer r.mx.Unlock()
for proto, s := range r.proto {
_, sticky := r.stickyProto[proto]
if sticky {
continue
}
if s.IsUnused() {
s.Done()
delete(r.proto, proto)
}
}
var deadPeers []peer.ID
for p, s := range r.peer {
_, sticky := r.stickyPeer[p]
if sticky {
continue
}
if s.IsUnused() {
s.Done()
delete(r.peer, p)
deadPeers = append(deadPeers, p)
}
}
for _, s := range r.svc {
s.Lock()
for _, p := range deadPeers {
ps, ok := s.peers[p]
if ok {
ps.Done()
delete(s.peers, p)
}
}
s.Unlock()
}
for _, s := range r.proto {
s.Lock()
for _, p := range deadPeers {
ps, ok := s.peers[p]
if ok {
ps.Done()
delete(s.peers, p)
}
}
s.Unlock()
}
}
func newSystemScope(limit Limit, rcmgr *resourceManager, name string) *systemScope {
return &systemScope{
resourceScope: newResourceScope(limit, nil, name, rcmgr.trace, rcmgr.metrics),
}
}
func newTransientScope(limit Limit, rcmgr *resourceManager, name string, systemScope *resourceScope) *transientScope {
return &transientScope{
resourceScope: newResourceScope(limit,
[]*resourceScope{systemScope},
name, rcmgr.trace, rcmgr.metrics),
system: rcmgr.system,
}
}
func newServiceScope(service string, limit Limit, rcmgr *resourceManager) *serviceScope {
return &serviceScope{
resourceScope: newResourceScope(limit,
[]*resourceScope{rcmgr.system.resourceScope},
fmt.Sprintf("service:%s", service), rcmgr.trace, rcmgr.metrics),
service: service,
rcmgr: rcmgr,
}
}
func newProtocolScope(proto protocol.ID, limit Limit, rcmgr *resourceManager) *protocolScope {
return &protocolScope{
resourceScope: newResourceScope(limit,
[]*resourceScope{rcmgr.system.resourceScope},
fmt.Sprintf("protocol:%s", proto), rcmgr.trace, rcmgr.metrics),
proto: proto,
rcmgr: rcmgr,
}
}
func newPeerScope(p peer.ID, limit Limit, rcmgr *resourceManager) *peerScope {
return &peerScope{
resourceScope: newResourceScope(limit,
[]*resourceScope{rcmgr.system.resourceScope},
peerScopeName(p), rcmgr.trace, rcmgr.metrics),
peer: p,
rcmgr: rcmgr,
}
}
func newConnectionScope(dir network.Direction, usefd bool, limit Limit, rcmgr *resourceManager, endpoint multiaddr.Multiaddr) *connectionScope {
return &connectionScope{
resourceScope: newResourceScope(limit,
[]*resourceScope{rcmgr.transient.resourceScope, rcmgr.system.resourceScope},
connScopeName(rcmgr.nextConnId()), rcmgr.trace, rcmgr.metrics),
dir: dir,
usefd: usefd,
rcmgr: rcmgr,
endpoint: endpoint,
}
}
func newAllowListedConnectionScope(dir network.Direction, usefd bool, limit Limit, rcmgr *resourceManager, endpoint multiaddr.Multiaddr) *connectionScope {
return &connectionScope{
resourceScope: newResourceScope(limit,
[]*resourceScope{rcmgr.allowlistedTransient.resourceScope, rcmgr.allowlistedSystem.resourceScope},
connScopeName(rcmgr.nextConnId()), rcmgr.trace, rcmgr.metrics),
dir: dir,
usefd: usefd,
rcmgr: rcmgr,
endpoint: endpoint,
isAllowlisted: true,
}
}
func newStreamScope(dir network.Direction, limit Limit, peer *peerScope, rcmgr *resourceManager) *streamScope {
return &streamScope{
resourceScope: newResourceScope(limit,
[]*resourceScope{peer.resourceScope, rcmgr.transient.resourceScope, rcmgr.system.resourceScope},
streamScopeName(rcmgr.nextStreamId()), rcmgr.trace, rcmgr.metrics),
dir: dir,
rcmgr: peer.rcmgr,
peer: peer,
}
}
func IsSystemScope(name string) bool {
return name == "system"
}
func IsTransientScope(name string) bool {
return name == "transient"
}
func streamScopeName(streamId int64) string {
return fmt.Sprintf("stream-%d", streamId)
}
func IsStreamScope(name string) bool {
return strings.HasPrefix(name, "stream-") && !IsSpan(name)
}
func connScopeName(streamId int64) string {
return fmt.Sprintf("conn-%d", streamId)
}
func IsConnScope(name string) bool {
return strings.HasPrefix(name, "conn-") && !IsSpan(name)
}
func peerScopeName(p peer.ID) string {
return fmt.Sprintf("peer:%s", p)
}
// PeerStrInScopeName returns "" if name is not a peerScopeName. Returns a string to avoid allocating a peer ID object
func PeerStrInScopeName(name string) string {
if !strings.HasPrefix(name, "peer:") || IsSpan(name) {
return ""
}
// Index to avoid allocating a new string
peerSplitIdx := strings.Index(name, "peer:")
if peerSplitIdx == -1 {
return ""
}
p := (name[peerSplitIdx+len("peer:"):])
return p
}
// ParseProtocolScopeName returns the service name if name is a serviceScopeName.
// Otherwise returns ""
func ParseProtocolScopeName(name string) string {
if strings.HasPrefix(name, "protocol:") && !IsSpan(name) {
if strings.Contains(name, "peer:") {
// This is a protocol peer scope
return ""
}
// Index to avoid allocating a new string
separatorIdx := strings.Index(name, ":")
if separatorIdx == -1 {
return ""
}
return name[separatorIdx+1:]
}
return ""
}
func (s *serviceScope) Name() string {
return s.service
}
func (s *serviceScope) getPeerScope(p peer.ID) *resourceScope {
s.Lock()
defer s.Unlock()
ps, ok := s.peers[p]
if ok {
ps.IncRef()
return ps
}
l := s.rcmgr.limits.GetServicePeerLimits(s.service)
if s.peers == nil {
s.peers = make(map[peer.ID]*resourceScope)
}
ps = newResourceScope(l, nil, fmt.Sprintf("%s.peer:%s", s.name, p), s.rcmgr.trace, s.rcmgr.metrics)
s.peers[p] = ps
ps.IncRef()
return ps
}
func (s *protocolScope) Protocol() protocol.ID {
return s.proto
}
func (s *protocolScope) getPeerScope(p peer.ID) *resourceScope {
s.Lock()
defer s.Unlock()
ps, ok := s.peers[p]
if ok {
ps.IncRef()
return ps
}
l := s.rcmgr.limits.GetProtocolPeerLimits(s.proto)
if s.peers == nil {
s.peers = make(map[peer.ID]*resourceScope)
}
ps = newResourceScope(l, nil, fmt.Sprintf("%s.peer:%s", s.name, p), s.rcmgr.trace, s.rcmgr.metrics)
s.peers[p] = ps
ps.IncRef()
return ps
}
func (s *peerScope) Peer() peer.ID {
return s.peer
}
func (s *connectionScope) PeerScope() network.PeerScope {
s.Lock()
defer s.Unlock()
// avoid nil is not nil footgun; go....
if s.peer == nil {
return nil
}
return s.peer
}
// transferAllowedToStandard transfers this connection scope from being part of
// the allowlist set of scopes to being part of the standard set of scopes.
// Happens when we first allowlisted this connection due to its IP, but later
// discovered that the peer id not what we expected.
func (s *connectionScope) transferAllowedToStandard() (err error) {
systemScope := s.rcmgr.system.resourceScope
transientScope := s.rcmgr.transient.resourceScope
stat := s.resourceScope.rc.stat()
for _, scope := range s.edges {
scope.ReleaseForChild(stat)
scope.DecRef() // removed from edges
}
s.edges = nil
if err := systemScope.ReserveForChild(stat); err != nil {
return err
}
systemScope.IncRef()
// Undo this if we fail later
defer func() {
if err != nil {
systemScope.ReleaseForChild(stat)
systemScope.DecRef()
}
}()
if err := transientScope.ReserveForChild(stat); err != nil {
return err
}
transientScope.IncRef()
// Update edges
s.edges = []*resourceScope{
systemScope,
transientScope,
}
return nil
}
func (s *connectionScope) SetPeer(p peer.ID) error {
s.Lock()
defer s.Unlock()
if s.peer != nil {
return fmt.Errorf("connection scope already attached to a peer")
}
system := s.rcmgr.system
transient := s.rcmgr.transient
if s.isAllowlisted {
system = s.rcmgr.allowlistedSystem
transient = s.rcmgr.allowlistedTransient
if !s.rcmgr.allowlist.AllowedPeerAndMultiaddr(p, s.endpoint) {
s.isAllowlisted = false
// This is not an allowed peer + multiaddr combination. We need to
// transfer this connection to the general scope. We'll do this first by
// transferring the connection to the system and transient scopes, then
// continue on with this function. The idea is that a connection
// shouldn't get the benefit of evading the transient scope because it
// was _almost_ an allowlisted connection.
if err := s.transferAllowedToStandard(); err != nil {
// Failed to transfer this connection to the standard scopes
return err
}
// set the system and transient scopes to the non-allowlisted ones
system = s.rcmgr.system
transient = s.rcmgr.transient
}
}
s.peer = s.rcmgr.getPeerScope(p)
// juggle resources from transient scope to peer scope
stat := s.resourceScope.rc.stat()
if err := s.peer.ReserveForChild(stat); err != nil {
s.peer.DecRef()
s.peer = nil
s.rcmgr.metrics.BlockPeer(p)
return err
}
transient.ReleaseForChild(stat)
transient.DecRef() // removed from edges
// update edges
edges := []*resourceScope{
s.peer.resourceScope,
system.resourceScope,
}
s.resourceScope.edges = edges
s.rcmgr.metrics.AllowPeer(p)
return nil
}
func (s *streamScope) ProtocolScope() network.ProtocolScope {
s.Lock()
defer s.Unlock()
// avoid nil is not nil footgun; go....
if s.proto == nil {
return nil
}
return s.proto
}
func (s *streamScope) SetProtocol(proto protocol.ID) error {
s.Lock()
defer s.Unlock()
if s.proto != nil {
return fmt.Errorf("stream scope already attached to a protocol")
}
s.proto = s.rcmgr.getProtocolScope(proto)
// juggle resources from transient scope to protocol scope
stat := s.resourceScope.rc.stat()
if err := s.proto.ReserveForChild(stat); err != nil {
s.proto.DecRef()
s.proto = nil
s.rcmgr.metrics.BlockProtocol(proto)
return err
}
s.peerProtoScope = s.proto.getPeerScope(s.peer.peer)
if err := s.peerProtoScope.ReserveForChild(stat); err != nil {
s.proto.ReleaseForChild(stat)
s.proto.DecRef()
s.proto = nil
s.peerProtoScope.DecRef()
s.peerProtoScope = nil
s.rcmgr.metrics.BlockProtocolPeer(proto, s.peer.peer)
return err
}
s.rcmgr.transient.ReleaseForChild(stat)
s.rcmgr.transient.DecRef() // removed from edges
// update edges
edges := []*resourceScope{
s.peer.resourceScope,
s.peerProtoScope,
s.proto.resourceScope,
s.rcmgr.system.resourceScope,
}
s.resourceScope.edges = edges
s.rcmgr.metrics.AllowProtocol(proto)
return nil
}
func (s *streamScope) ServiceScope() network.ServiceScope {
s.Lock()
defer s.Unlock()
// avoid nil is not nil footgun; go....
if s.svc == nil {
return nil
}
return s.svc
}
func (s *streamScope) SetService(svc string) error {
s.Lock()
defer s.Unlock()
if s.svc != nil {
return fmt.Errorf("stream scope already attached to a service")
}
if s.proto == nil {
return fmt.Errorf("stream scope not attached to a protocol")
}
s.svc = s.rcmgr.getServiceScope(svc)
// reserve resources in service
stat := s.resourceScope.rc.stat()
if err := s.svc.ReserveForChild(stat); err != nil {
s.svc.DecRef()
s.svc = nil
s.rcmgr.metrics.BlockService(svc)
return err
}
// get the per peer service scope constraint, if any
s.peerSvcScope = s.svc.getPeerScope(s.peer.peer)
if err := s.peerSvcScope.ReserveForChild(stat); err != nil {
s.svc.ReleaseForChild(stat)
s.svc.DecRef()
s.svc = nil
s.peerSvcScope.DecRef()
s.peerSvcScope = nil
s.rcmgr.metrics.BlockServicePeer(svc, s.peer.peer)
return err
}
// update edges
edges := []*resourceScope{
s.peer.resourceScope,
s.peerProtoScope,
s.peerSvcScope,
s.proto.resourceScope,
s.svc.resourceScope,
s.rcmgr.system.resourceScope,
}
s.resourceScope.edges = edges
s.rcmgr.metrics.AllowService(svc)
return nil
}
func (s *streamScope) PeerScope() network.PeerScope {
s.Lock()
defer s.Unlock()
// avoid nil is not nil footgun; go....
if s.peer == nil {
return nil
}
return s.peer
}

View File

@@ -0,0 +1,814 @@
package rcmgr
import (
"fmt"
"math"
"math/big"
"strings"
"sync"
"github.com/libp2p/go-libp2p/core/network"
)
// resources tracks the current state of resource consumption
type resources struct {
limit Limit
nconnsIn, nconnsOut int
nstreamsIn, nstreamsOut int
nfd int
memory int64
}
// A resourceScope can be a DAG, where a downstream node is not allowed to outlive an upstream node
// (ie cannot call Done in the upstream node before the downstream node) and account for resources
// using a linearized parent set.
// A resourceScope can be a span scope, where it has a specific owner; span scopes create a tree rooted
// at the owner (which can be a DAG scope) and can outlive their parents -- this is important because
// span scopes are the main *user* interface for memory management, and the user may call
// Done in a span scope after the system has closed the root of the span tree in some background
// goroutine.
// If we didn't make this distinction we would have a double release problem in that case.
type resourceScope struct {
sync.Mutex
done bool
refCnt int
spanID int
rc resources
owner *resourceScope // set in span scopes, which define trees
edges []*resourceScope // set in DAG scopes, it's the linearized parent set
name string // for debugging purposes
trace *trace // debug tracing
metrics *metrics // metrics collection
}
var _ network.ResourceScope = (*resourceScope)(nil)
var _ network.ResourceScopeSpan = (*resourceScope)(nil)
func newResourceScope(limit Limit, edges []*resourceScope, name string, trace *trace, metrics *metrics) *resourceScope {
for _, e := range edges {
e.IncRef()
}
r := &resourceScope{
rc: resources{limit: limit},
edges: edges,
name: name,
trace: trace,
metrics: metrics,
}
r.trace.CreateScope(name, limit)
return r
}
func newResourceScopeSpan(owner *resourceScope, id int) *resourceScope {
r := &resourceScope{
rc: resources{limit: owner.rc.limit},
owner: owner,
name: fmt.Sprintf("%s.span-%d", owner.name, id),
trace: owner.trace,
metrics: owner.metrics,
}
r.trace.CreateScope(r.name, r.rc.limit)
return r
}
// IsSpan will return true if this name was created by newResourceScopeSpan
func IsSpan(name string) bool {
return strings.Contains(name, ".span-")
}
func addInt64WithOverflow(a int64, b int64) (c int64, ok bool) {
c = a + b
return c, (c > a) == (b > 0)
}
// mulInt64WithOverflow checks for overflow in multiplying two int64s. See
// https://groups.google.com/g/golang-nuts/c/h5oSN5t3Au4/m/KaNQREhZh0QJ
func mulInt64WithOverflow(a, b int64) (c int64, ok bool) {
const mostPositive = 1<<63 - 1
const mostNegative = -(mostPositive + 1)
c = a * b
if a == 0 || b == 0 || a == 1 || b == 1 {
return c, true
}
if a == mostNegative || b == mostNegative {
return c, false
}
return c, c/b == a
}
// Resources implementation
func (rc *resources) checkMemory(rsvp int64, prio uint8) error {
if rsvp < 0 {
return fmt.Errorf("can't reserve negative memory. rsvp=%v", rsvp)
}
limit := rc.limit.GetMemoryLimit()
if limit == math.MaxInt64 {
// Special case where we've set max limits.
return nil
}
newmem, addOk := addInt64WithOverflow(rc.memory, rsvp)
threshold, mulOk := mulInt64WithOverflow(1+int64(prio), limit)
if !mulOk {
thresholdBig := big.NewInt(limit)
thresholdBig = thresholdBig.Mul(thresholdBig, big.NewInt(1+int64(prio)))
thresholdBig.Rsh(thresholdBig, 8) // Divide 256
if !thresholdBig.IsInt64() {
// Shouldn't happen since the threshold can only be <= limit
threshold = limit
}
threshold = thresholdBig.Int64()
} else {
threshold = threshold / 256
}
if !addOk || newmem > threshold {
return &ErrMemoryLimitExceeded{
current: rc.memory,
attempted: rsvp,
limit: limit,
priority: prio,
err: network.ErrResourceLimitExceeded,
}
}
return nil
}
func (rc *resources) reserveMemory(size int64, prio uint8) error {
if err := rc.checkMemory(size, prio); err != nil {
return err
}
rc.memory += size
return nil
}
func (rc *resources) releaseMemory(size int64) {
rc.memory -= size
// sanity check for bugs upstream
if rc.memory < 0 {
log.Warn("BUG: too much memory released")
rc.memory = 0
}
}
func (rc *resources) addStream(dir network.Direction) error {
if dir == network.DirInbound {
return rc.addStreams(1, 0)
}
return rc.addStreams(0, 1)
}
func (rc *resources) addStreams(incount, outcount int) error {
if incount > 0 {
limit := rc.limit.GetStreamLimit(network.DirInbound)
if rc.nstreamsIn+incount > limit {
return &ErrStreamOrConnLimitExceeded{
current: rc.nstreamsIn,
attempted: incount,
limit: limit,
err: fmt.Errorf("cannot reserve inbound stream: %w", network.ErrResourceLimitExceeded),
}
}
}
if outcount > 0 {
limit := rc.limit.GetStreamLimit(network.DirOutbound)
if rc.nstreamsOut+outcount > limit {
return &ErrStreamOrConnLimitExceeded{
current: rc.nstreamsOut,
attempted: outcount,
limit: limit,
err: fmt.Errorf("cannot reserve outbound stream: %w", network.ErrResourceLimitExceeded),
}
}
}
if limit := rc.limit.GetStreamTotalLimit(); rc.nstreamsIn+incount+rc.nstreamsOut+outcount > limit {
return &ErrStreamOrConnLimitExceeded{
current: rc.nstreamsIn + rc.nstreamsOut,
attempted: incount + outcount,
limit: limit,
err: fmt.Errorf("cannot reserve stream: %w", network.ErrResourceLimitExceeded),
}
}
rc.nstreamsIn += incount
rc.nstreamsOut += outcount
return nil
}
func (rc *resources) removeStream(dir network.Direction) {
if dir == network.DirInbound {
rc.removeStreams(1, 0)
} else {
rc.removeStreams(0, 1)
}
}
func (rc *resources) removeStreams(incount, outcount int) {
rc.nstreamsIn -= incount
rc.nstreamsOut -= outcount
if rc.nstreamsIn < 0 {
log.Warn("BUG: too many inbound streams released")
rc.nstreamsIn = 0
}
if rc.nstreamsOut < 0 {
log.Warn("BUG: too many outbound streams released")
rc.nstreamsOut = 0
}
}
func (rc *resources) addConn(dir network.Direction, usefd bool) error {
var fd int
if usefd {
fd = 1
}
if dir == network.DirInbound {
return rc.addConns(1, 0, fd)
}
return rc.addConns(0, 1, fd)
}
func (rc *resources) addConns(incount, outcount, fdcount int) error {
if incount > 0 {
limit := rc.limit.GetConnLimit(network.DirInbound)
if rc.nconnsIn+incount > limit {
return &ErrStreamOrConnLimitExceeded{
current: rc.nconnsIn,
attempted: incount,
limit: limit,
err: fmt.Errorf("cannot reserve inbound connection: %w", network.ErrResourceLimitExceeded),
}
}
}
if outcount > 0 {
limit := rc.limit.GetConnLimit(network.DirOutbound)
if rc.nconnsOut+outcount > limit {
return &ErrStreamOrConnLimitExceeded{
current: rc.nconnsOut,
attempted: outcount,
limit: limit,
err: fmt.Errorf("cannot reserve outbound connection: %w", network.ErrResourceLimitExceeded),
}
}
}
if connLimit := rc.limit.GetConnTotalLimit(); rc.nconnsIn+incount+rc.nconnsOut+outcount > connLimit {
return &ErrStreamOrConnLimitExceeded{
current: rc.nconnsIn + rc.nconnsOut,
attempted: incount + outcount,
limit: connLimit,
err: fmt.Errorf("cannot reserve connection: %w", network.ErrResourceLimitExceeded),
}
}
if fdcount > 0 {
limit := rc.limit.GetFDLimit()
if rc.nfd+fdcount > limit {
return &ErrStreamOrConnLimitExceeded{
current: rc.nfd,
attempted: fdcount,
limit: limit,
err: fmt.Errorf("cannot reserve file descriptor: %w", network.ErrResourceLimitExceeded),
}
}
}
rc.nconnsIn += incount
rc.nconnsOut += outcount
rc.nfd += fdcount
return nil
}
func (rc *resources) removeConn(dir network.Direction, usefd bool) {
var fd int
if usefd {
fd = 1
}
if dir == network.DirInbound {
rc.removeConns(1, 0, fd)
} else {
rc.removeConns(0, 1, fd)
}
}
func (rc *resources) removeConns(incount, outcount, fdcount int) {
rc.nconnsIn -= incount
rc.nconnsOut -= outcount
rc.nfd -= fdcount
if rc.nconnsIn < 0 {
log.Warn("BUG: too many inbound connections released")
rc.nconnsIn = 0
}
if rc.nconnsOut < 0 {
log.Warn("BUG: too many outbound connections released")
rc.nconnsOut = 0
}
if rc.nfd < 0 {
log.Warn("BUG: too many file descriptors released")
rc.nfd = 0
}
}
func (rc *resources) stat() network.ScopeStat {
return network.ScopeStat{
Memory: rc.memory,
NumStreamsInbound: rc.nstreamsIn,
NumStreamsOutbound: rc.nstreamsOut,
NumConnsInbound: rc.nconnsIn,
NumConnsOutbound: rc.nconnsOut,
NumFD: rc.nfd,
}
}
// resourceScope implementation
func (s *resourceScope) wrapError(err error) error {
return fmt.Errorf("%s: %w", s.name, err)
}
func (s *resourceScope) ReserveMemory(size int, prio uint8) error {
s.Lock()
defer s.Unlock()
if s.done {
return s.wrapError(network.ErrResourceScopeClosed)
}
if err := s.rc.reserveMemory(int64(size), prio); err != nil {
log.Debugw("blocked memory reservation", logValuesMemoryLimit(s.name, "", s.rc.stat(), err)...)
s.trace.BlockReserveMemory(s.name, prio, int64(size), s.rc.memory)
s.metrics.BlockMemory(size)
return s.wrapError(err)
}
if err := s.reserveMemoryForEdges(size, prio); err != nil {
s.rc.releaseMemory(int64(size))
s.metrics.BlockMemory(size)
return s.wrapError(err)
}
s.trace.ReserveMemory(s.name, prio, int64(size), s.rc.memory)
s.metrics.AllowMemory(size)
return nil
}
func (s *resourceScope) reserveMemoryForEdges(size int, prio uint8) error {
if s.owner != nil {
return s.owner.ReserveMemory(size, prio)
}
var reserved int
var err error
for _, e := range s.edges {
var stat network.ScopeStat
stat, err = e.ReserveMemoryForChild(int64(size), prio)
if err != nil {
log.Debugw("blocked memory reservation from constraining edge", logValuesMemoryLimit(s.name, e.name, stat, err)...)
break
}
reserved++
}
if err != nil {
// we failed because of a constraint; undo memory reservations
for _, e := range s.edges[:reserved] {
e.ReleaseMemoryForChild(int64(size))
}
}
return err
}
func (s *resourceScope) releaseMemoryForEdges(size int) {
if s.owner != nil {
s.owner.ReleaseMemory(size)
return
}
for _, e := range s.edges {
e.ReleaseMemoryForChild(int64(size))
}
}
func (s *resourceScope) ReserveMemoryForChild(size int64, prio uint8) (network.ScopeStat, error) {
s.Lock()
defer s.Unlock()
if s.done {
return s.rc.stat(), s.wrapError(network.ErrResourceScopeClosed)
}
if err := s.rc.reserveMemory(size, prio); err != nil {
s.trace.BlockReserveMemory(s.name, prio, size, s.rc.memory)
return s.rc.stat(), s.wrapError(err)
}
s.trace.ReserveMemory(s.name, prio, size, s.rc.memory)
return network.ScopeStat{}, nil
}
func (s *resourceScope) ReleaseMemory(size int) {
s.Lock()
defer s.Unlock()
if s.done {
return
}
s.rc.releaseMemory(int64(size))
s.releaseMemoryForEdges(size)
s.trace.ReleaseMemory(s.name, int64(size), s.rc.memory)
}
func (s *resourceScope) ReleaseMemoryForChild(size int64) {
s.Lock()
defer s.Unlock()
if s.done {
return
}
s.rc.releaseMemory(size)
s.trace.ReleaseMemory(s.name, size, s.rc.memory)
}
func (s *resourceScope) AddStream(dir network.Direction) error {
s.Lock()
defer s.Unlock()
if s.done {
return s.wrapError(network.ErrResourceScopeClosed)
}
if err := s.rc.addStream(dir); err != nil {
log.Debugw("blocked stream", logValuesStreamLimit(s.name, "", dir, s.rc.stat(), err)...)
s.trace.BlockAddStream(s.name, dir, s.rc.nstreamsIn, s.rc.nstreamsOut)
return s.wrapError(err)
}
if err := s.addStreamForEdges(dir); err != nil {
s.rc.removeStream(dir)
return s.wrapError(err)
}
s.trace.AddStream(s.name, dir, s.rc.nstreamsIn, s.rc.nstreamsOut)
return nil
}
func (s *resourceScope) addStreamForEdges(dir network.Direction) error {
if s.owner != nil {
return s.owner.AddStream(dir)
}
var err error
var reserved int
for _, e := range s.edges {
var stat network.ScopeStat
stat, err = e.AddStreamForChild(dir)
if err != nil {
log.Debugw("blocked stream from constraining edge", logValuesStreamLimit(s.name, e.name, dir, stat, err)...)
break
}
reserved++
}
if err != nil {
for _, e := range s.edges[:reserved] {
e.RemoveStreamForChild(dir)
}
}
return err
}
func (s *resourceScope) AddStreamForChild(dir network.Direction) (network.ScopeStat, error) {
s.Lock()
defer s.Unlock()
if s.done {
return s.rc.stat(), s.wrapError(network.ErrResourceScopeClosed)
}
if err := s.rc.addStream(dir); err != nil {
s.trace.BlockAddStream(s.name, dir, s.rc.nstreamsIn, s.rc.nstreamsOut)
return s.rc.stat(), s.wrapError(err)
}
s.trace.AddStream(s.name, dir, s.rc.nstreamsIn, s.rc.nstreamsOut)
return network.ScopeStat{}, nil
}
func (s *resourceScope) RemoveStream(dir network.Direction) {
s.Lock()
defer s.Unlock()
if s.done {
return
}
s.rc.removeStream(dir)
s.removeStreamForEdges(dir)
s.trace.RemoveStream(s.name, dir, s.rc.nstreamsIn, s.rc.nstreamsOut)
}
func (s *resourceScope) removeStreamForEdges(dir network.Direction) {
if s.owner != nil {
s.owner.RemoveStream(dir)
return
}
for _, e := range s.edges {
e.RemoveStreamForChild(dir)
}
}
func (s *resourceScope) RemoveStreamForChild(dir network.Direction) {
s.Lock()
defer s.Unlock()
if s.done {
return
}
s.rc.removeStream(dir)
s.trace.RemoveStream(s.name, dir, s.rc.nstreamsIn, s.rc.nstreamsOut)
}
func (s *resourceScope) AddConn(dir network.Direction, usefd bool) error {
s.Lock()
defer s.Unlock()
if s.done {
return s.wrapError(network.ErrResourceScopeClosed)
}
if err := s.rc.addConn(dir, usefd); err != nil {
log.Debugw("blocked connection", logValuesConnLimit(s.name, "", dir, usefd, s.rc.stat(), err)...)
s.trace.BlockAddConn(s.name, dir, usefd, s.rc.nconnsIn, s.rc.nconnsOut, s.rc.nfd)
return s.wrapError(err)
}
if err := s.addConnForEdges(dir, usefd); err != nil {
s.rc.removeConn(dir, usefd)
return s.wrapError(err)
}
s.trace.AddConn(s.name, dir, usefd, s.rc.nconnsIn, s.rc.nconnsOut, s.rc.nfd)
return nil
}
func (s *resourceScope) addConnForEdges(dir network.Direction, usefd bool) error {
if s.owner != nil {
return s.owner.AddConn(dir, usefd)
}
var err error
var reserved int
for _, e := range s.edges {
var stat network.ScopeStat
stat, err = e.AddConnForChild(dir, usefd)
if err != nil {
log.Debugw("blocked connection from constraining edge", logValuesConnLimit(s.name, e.name, dir, usefd, stat, err)...)
break
}
reserved++
}
if err != nil {
for _, e := range s.edges[:reserved] {
e.RemoveConnForChild(dir, usefd)
}
}
return err
}
func (s *resourceScope) AddConnForChild(dir network.Direction, usefd bool) (network.ScopeStat, error) {
s.Lock()
defer s.Unlock()
if s.done {
return s.rc.stat(), s.wrapError(network.ErrResourceScopeClosed)
}
if err := s.rc.addConn(dir, usefd); err != nil {
s.trace.BlockAddConn(s.name, dir, usefd, s.rc.nconnsIn, s.rc.nconnsOut, s.rc.nfd)
return s.rc.stat(), s.wrapError(err)
}
s.trace.AddConn(s.name, dir, usefd, s.rc.nconnsIn, s.rc.nconnsOut, s.rc.nfd)
return network.ScopeStat{}, nil
}
func (s *resourceScope) RemoveConn(dir network.Direction, usefd bool) {
s.Lock()
defer s.Unlock()
if s.done {
return
}
s.rc.removeConn(dir, usefd)
s.removeConnForEdges(dir, usefd)
s.trace.RemoveConn(s.name, dir, usefd, s.rc.nconnsIn, s.rc.nconnsOut, s.rc.nfd)
}
func (s *resourceScope) removeConnForEdges(dir network.Direction, usefd bool) {
if s.owner != nil {
s.owner.RemoveConn(dir, usefd)
}
for _, e := range s.edges {
e.RemoveConnForChild(dir, usefd)
}
}
func (s *resourceScope) RemoveConnForChild(dir network.Direction, usefd bool) {
s.Lock()
defer s.Unlock()
if s.done {
return
}
s.rc.removeConn(dir, usefd)
s.trace.RemoveConn(s.name, dir, usefd, s.rc.nconnsIn, s.rc.nconnsOut, s.rc.nfd)
}
func (s *resourceScope) ReserveForChild(st network.ScopeStat) error {
s.Lock()
defer s.Unlock()
if s.done {
return s.wrapError(network.ErrResourceScopeClosed)
}
if err := s.rc.reserveMemory(st.Memory, network.ReservationPriorityAlways); err != nil {
s.trace.BlockReserveMemory(s.name, 255, st.Memory, s.rc.memory)
return s.wrapError(err)
}
if err := s.rc.addStreams(st.NumStreamsInbound, st.NumStreamsOutbound); err != nil {
s.trace.BlockAddStreams(s.name, st.NumStreamsInbound, st.NumStreamsOutbound, s.rc.nstreamsIn, s.rc.nstreamsOut)
s.rc.releaseMemory(st.Memory)
return s.wrapError(err)
}
if err := s.rc.addConns(st.NumConnsInbound, st.NumConnsOutbound, st.NumFD); err != nil {
s.trace.BlockAddConns(s.name, st.NumConnsInbound, st.NumConnsOutbound, st.NumFD, s.rc.nconnsIn, s.rc.nconnsOut, s.rc.nfd)
s.rc.releaseMemory(st.Memory)
s.rc.removeStreams(st.NumStreamsInbound, st.NumStreamsOutbound)
return s.wrapError(err)
}
s.trace.ReserveMemory(s.name, 255, st.Memory, s.rc.memory)
s.trace.AddStreams(s.name, st.NumStreamsInbound, st.NumStreamsOutbound, s.rc.nstreamsIn, s.rc.nstreamsOut)
s.trace.AddConns(s.name, st.NumConnsInbound, st.NumConnsOutbound, st.NumFD, s.rc.nconnsIn, s.rc.nconnsOut, s.rc.nfd)
return nil
}
func (s *resourceScope) ReleaseForChild(st network.ScopeStat) {
s.Lock()
defer s.Unlock()
if s.done {
return
}
s.rc.releaseMemory(st.Memory)
s.rc.removeStreams(st.NumStreamsInbound, st.NumStreamsOutbound)
s.rc.removeConns(st.NumConnsInbound, st.NumConnsOutbound, st.NumFD)
s.trace.ReleaseMemory(s.name, st.Memory, s.rc.memory)
s.trace.RemoveStreams(s.name, st.NumStreamsInbound, st.NumStreamsOutbound, s.rc.nstreamsIn, s.rc.nstreamsOut)
s.trace.RemoveConns(s.name, st.NumConnsInbound, st.NumConnsOutbound, st.NumFD, s.rc.nconnsIn, s.rc.nconnsOut, s.rc.nfd)
}
func (s *resourceScope) ReleaseResources(st network.ScopeStat) {
s.Lock()
defer s.Unlock()
if s.done {
return
}
s.rc.releaseMemory(st.Memory)
s.rc.removeStreams(st.NumStreamsInbound, st.NumStreamsOutbound)
s.rc.removeConns(st.NumConnsInbound, st.NumConnsOutbound, st.NumFD)
if s.owner != nil {
s.owner.ReleaseResources(st)
} else {
for _, e := range s.edges {
e.ReleaseForChild(st)
}
}
s.trace.ReleaseMemory(s.name, st.Memory, s.rc.memory)
s.trace.RemoveStreams(s.name, st.NumStreamsInbound, st.NumStreamsOutbound, s.rc.nstreamsIn, s.rc.nstreamsOut)
s.trace.RemoveConns(s.name, st.NumConnsInbound, st.NumConnsOutbound, st.NumFD, s.rc.nconnsIn, s.rc.nconnsOut, s.rc.nfd)
}
func (s *resourceScope) nextSpanID() int {
s.spanID++
return s.spanID
}
func (s *resourceScope) BeginSpan() (network.ResourceScopeSpan, error) {
s.Lock()
defer s.Unlock()
if s.done {
return nil, s.wrapError(network.ErrResourceScopeClosed)
}
s.refCnt++
return newResourceScopeSpan(s, s.nextSpanID()), nil
}
func (s *resourceScope) Done() {
s.Lock()
defer s.Unlock()
if s.done {
return
}
stat := s.rc.stat()
if s.owner != nil {
s.owner.ReleaseResources(stat)
s.owner.DecRef()
} else {
for _, e := range s.edges {
e.ReleaseForChild(stat)
e.DecRef()
}
}
s.rc.nstreamsIn = 0
s.rc.nstreamsOut = 0
s.rc.nconnsIn = 0
s.rc.nconnsOut = 0
s.rc.nfd = 0
s.rc.memory = 0
s.done = true
s.trace.DestroyScope(s.name)
}
func (s *resourceScope) Stat() network.ScopeStat {
s.Lock()
defer s.Unlock()
return s.rc.stat()
}
func (s *resourceScope) IncRef() {
s.Lock()
defer s.Unlock()
s.refCnt++
}
func (s *resourceScope) DecRef() {
s.Lock()
defer s.Unlock()
s.refCnt--
}
func (s *resourceScope) IsUnused() bool {
s.Lock()
defer s.Unlock()
if s.done {
return true
}
if s.refCnt > 0 {
return false
}
st := s.rc.stat()
return st.NumStreamsInbound == 0 &&
st.NumStreamsOutbound == 0 &&
st.NumConnsInbound == 0 &&
st.NumConnsOutbound == 0 &&
st.NumFD == 0
}

View File

@@ -0,0 +1,390 @@
package rcmgr
import (
"strings"
"github.com/libp2p/go-libp2p/p2p/metricshelper"
"github.com/prometheus/client_golang/prometheus"
)
const metricNamespace = "libp2p_rcmgr"
var (
// Conns
conns = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "connections",
Help: "Number of Connections",
}, []string{"dir", "scope"})
connsInboundSystem = conns.With(prometheus.Labels{"dir": "inbound", "scope": "system"})
connsInboundTransient = conns.With(prometheus.Labels{"dir": "inbound", "scope": "transient"})
connsOutboundSystem = conns.With(prometheus.Labels{"dir": "outbound", "scope": "system"})
connsOutboundTransient = conns.With(prometheus.Labels{"dir": "outbound", "scope": "transient"})
oneTenThenExpDistributionBuckets = []float64{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 32, 64, 128, 256,
}
// PeerConns
peerConns = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: metricNamespace,
Name: "peer_connections",
Buckets: oneTenThenExpDistributionBuckets,
Help: "Number of connections this peer has",
}, []string{"dir"})
peerConnsInbound = peerConns.With(prometheus.Labels{"dir": "inbound"})
peerConnsOutbound = peerConns.With(prometheus.Labels{"dir": "outbound"})
// Lets us build a histogram of our current state. See https://github.com/libp2p/go-libp2p-resource-manager/pull/54#discussion_r911244757 for more information.
previousPeerConns = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: metricNamespace,
Name: "previous_peer_connections",
Buckets: oneTenThenExpDistributionBuckets,
Help: "Number of connections this peer previously had. This is used to get the current connection number per peer histogram by subtracting this from the peer_connections histogram",
}, []string{"dir"})
previousPeerConnsInbound = previousPeerConns.With(prometheus.Labels{"dir": "inbound"})
previousPeerConnsOutbound = previousPeerConns.With(prometheus.Labels{"dir": "outbound"})
// Streams
streams = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "streams",
Help: "Number of Streams",
}, []string{"dir", "scope", "protocol"})
peerStreams = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: metricNamespace,
Name: "peer_streams",
Buckets: oneTenThenExpDistributionBuckets,
Help: "Number of streams this peer has",
}, []string{"dir"})
peerStreamsInbound = peerStreams.With(prometheus.Labels{"dir": "inbound"})
peerStreamsOutbound = peerStreams.With(prometheus.Labels{"dir": "outbound"})
previousPeerStreams = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: metricNamespace,
Name: "previous_peer_streams",
Buckets: oneTenThenExpDistributionBuckets,
Help: "Number of streams this peer has",
}, []string{"dir"})
previousPeerStreamsInbound = previousPeerStreams.With(prometheus.Labels{"dir": "inbound"})
previousPeerStreamsOutbound = previousPeerStreams.With(prometheus.Labels{"dir": "outbound"})
// Memory
memoryTotal = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "memory",
Help: "Amount of memory reserved as reported to the Resource Manager",
}, []string{"scope", "protocol"})
// PeerMemory
peerMemory = prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: metricNamespace,
Name: "peer_memory",
Buckets: memDistribution,
Help: "How many peers have reserved this bucket of memory, as reported to the Resource Manager",
})
previousPeerMemory = prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: metricNamespace,
Name: "previous_peer_memory",
Buckets: memDistribution,
Help: "How many peers have previously reserved this bucket of memory, as reported to the Resource Manager",
})
// ConnMemory
connMemory = prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: metricNamespace,
Name: "conn_memory",
Buckets: memDistribution,
Help: "How many conns have reserved this bucket of memory, as reported to the Resource Manager",
})
previousConnMemory = prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: metricNamespace,
Name: "previous_conn_memory",
Buckets: memDistribution,
Help: "How many conns have previously reserved this bucket of memory, as reported to the Resource Manager",
})
// FDs
fds = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "fds",
Help: "Number of file descriptors reserved as reported to the Resource Manager",
}, []string{"scope"})
fdsSystem = fds.With(prometheus.Labels{"scope": "system"})
fdsTransient = fds.With(prometheus.Labels{"scope": "transient"})
// Blocked resources
blockedResources = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "blocked_resources",
Help: "Number of blocked resources",
}, []string{"dir", "scope", "resource"})
)
var (
memDistribution = []float64{
1 << 10, // 1KB
4 << 10, // 4KB
32 << 10, // 32KB
1 << 20, // 1MB
32 << 20, // 32MB
256 << 20, // 256MB
512 << 20, // 512MB
1 << 30, // 1GB
2 << 30, // 2GB
4 << 30, // 4GB
}
)
func MustRegisterWith(reg prometheus.Registerer) {
metricshelper.RegisterCollectors(reg,
conns,
peerConns,
previousPeerConns,
streams,
peerStreams,
previousPeerStreams,
memoryTotal,
peerMemory,
previousPeerMemory,
connMemory,
previousConnMemory,
fds,
blockedResources,
)
}
func WithMetricsDisabled() Option {
return func(r *resourceManager) error {
r.disableMetrics = true
return nil
}
}
// StatsTraceReporter reports stats on the resource manager using its traces.
type StatsTraceReporter struct{}
func NewStatsTraceReporter() (StatsTraceReporter, error) {
// TODO tell prometheus the system limits
return StatsTraceReporter{}, nil
}
func (r StatsTraceReporter) ConsumeEvent(evt TraceEvt) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
r.consumeEventWithLabelSlice(evt, tags)
}
// Separate func so that we can test that this function does not allocate. The syncPool may allocate.
func (r StatsTraceReporter) consumeEventWithLabelSlice(evt TraceEvt, tags *[]string) {
switch evt.Type {
case TraceAddStreamEvt, TraceRemoveStreamEvt:
if p := PeerStrInScopeName(evt.Name); p != "" {
// Aggregated peer stats. Counts how many peers have N number of streams open.
// Uses two buckets aggregations. One to count how many streams the
// peer has now. The other to count the negative value, or how many
// streams did the peer use to have. When looking at the data you
// take the difference from the two.
oldStreamsOut := int64(evt.StreamsOut - evt.DeltaOut)
peerStreamsOut := int64(evt.StreamsOut)
if oldStreamsOut != peerStreamsOut {
if oldStreamsOut != 0 {
previousPeerStreamsOutbound.Observe(float64(oldStreamsOut))
}
if peerStreamsOut != 0 {
peerStreamsOutbound.Observe(float64(peerStreamsOut))
}
}
oldStreamsIn := int64(evt.StreamsIn - evt.DeltaIn)
peerStreamsIn := int64(evt.StreamsIn)
if oldStreamsIn != peerStreamsIn {
if oldStreamsIn != 0 {
previousPeerStreamsInbound.Observe(float64(oldStreamsIn))
}
if peerStreamsIn != 0 {
peerStreamsInbound.Observe(float64(peerStreamsIn))
}
}
} else {
if evt.DeltaOut != 0 {
if IsSystemScope(evt.Name) || IsTransientScope(evt.Name) {
*tags = (*tags)[:0]
*tags = append(*tags, "outbound", evt.Name, "")
streams.WithLabelValues(*tags...).Set(float64(evt.StreamsOut))
} else if proto := ParseProtocolScopeName(evt.Name); proto != "" {
*tags = (*tags)[:0]
*tags = append(*tags, "outbound", "protocol", proto)
streams.WithLabelValues(*tags...).Set(float64(evt.StreamsOut))
} else {
// Not measuring service scope, connscope, servicepeer and protocolpeer. Lots of data, and
// you can use aggregated peer stats + service stats to infer
// this.
break
}
}
if evt.DeltaIn != 0 {
if IsSystemScope(evt.Name) || IsTransientScope(evt.Name) {
*tags = (*tags)[:0]
*tags = append(*tags, "inbound", evt.Name, "")
streams.WithLabelValues(*tags...).Set(float64(evt.StreamsIn))
} else if proto := ParseProtocolScopeName(evt.Name); proto != "" {
*tags = (*tags)[:0]
*tags = append(*tags, "inbound", "protocol", proto)
streams.WithLabelValues(*tags...).Set(float64(evt.StreamsIn))
} else {
// Not measuring service scope, connscope, servicepeer and protocolpeer. Lots of data, and
// you can use aggregated peer stats + service stats to infer
// this.
break
}
}
}
case TraceAddConnEvt, TraceRemoveConnEvt:
if p := PeerStrInScopeName(evt.Name); p != "" {
// Aggregated peer stats. Counts how many peers have N number of connections.
// Uses two buckets aggregations. One to count how many streams the
// peer has now. The other to count the negative value, or how many
// conns did the peer use to have. When looking at the data you
// take the difference from the two.
oldConnsOut := int64(evt.ConnsOut - evt.DeltaOut)
connsOut := int64(evt.ConnsOut)
if oldConnsOut != connsOut {
if oldConnsOut != 0 {
previousPeerConnsOutbound.Observe(float64(oldConnsOut))
}
if connsOut != 0 {
peerConnsOutbound.Observe(float64(connsOut))
}
}
oldConnsIn := int64(evt.ConnsIn - evt.DeltaIn)
connsIn := int64(evt.ConnsIn)
if oldConnsIn != connsIn {
if oldConnsIn != 0 {
previousPeerConnsInbound.Observe(float64(oldConnsIn))
}
if connsIn != 0 {
peerConnsInbound.Observe(float64(connsIn))
}
}
} else {
if IsConnScope(evt.Name) {
// Not measuring this. I don't think it's useful.
break
}
if IsSystemScope(evt.Name) {
connsInboundSystem.Set(float64(evt.ConnsIn))
connsOutboundSystem.Set(float64(evt.ConnsOut))
} else if IsTransientScope(evt.Name) {
connsInboundTransient.Set(float64(evt.ConnsIn))
connsOutboundTransient.Set(float64(evt.ConnsOut))
}
// Represents the delta in fds
if evt.Delta != 0 {
if IsSystemScope(evt.Name) {
fdsSystem.Set(float64(evt.FD))
} else if IsTransientScope(evt.Name) {
fdsTransient.Set(float64(evt.FD))
}
}
}
case TraceReserveMemoryEvt, TraceReleaseMemoryEvt:
if p := PeerStrInScopeName(evt.Name); p != "" {
oldMem := evt.Memory - evt.Delta
if oldMem != evt.Memory {
if oldMem != 0 {
previousPeerMemory.Observe(float64(oldMem))
}
if evt.Memory != 0 {
peerMemory.Observe(float64(evt.Memory))
}
}
} else if IsConnScope(evt.Name) {
oldMem := evt.Memory - evt.Delta
if oldMem != evt.Memory {
if oldMem != 0 {
previousConnMemory.Observe(float64(oldMem))
}
if evt.Memory != 0 {
connMemory.Observe(float64(evt.Memory))
}
}
} else {
if IsSystemScope(evt.Name) || IsTransientScope(evt.Name) {
*tags = (*tags)[:0]
*tags = append(*tags, evt.Name, "")
memoryTotal.WithLabelValues(*tags...).Set(float64(evt.Memory))
} else if proto := ParseProtocolScopeName(evt.Name); proto != "" {
*tags = (*tags)[:0]
*tags = append(*tags, "protocol", proto)
memoryTotal.WithLabelValues(*tags...).Set(float64(evt.Memory))
} else {
// Not measuring connscope, servicepeer and protocolpeer. Lots of data, and
// you can use aggregated peer stats + service stats to infer
// this.
break
}
}
case TraceBlockAddConnEvt, TraceBlockAddStreamEvt, TraceBlockReserveMemoryEvt:
var resource string
if evt.Type == TraceBlockAddConnEvt {
resource = "connection"
} else if evt.Type == TraceBlockAddStreamEvt {
resource = "stream"
} else {
resource = "memory"
}
scopeName := evt.Name
// Only the top scopeName. We don't want to get the peerid here.
// Using indexes and slices to avoid allocating.
scopeSplitIdx := strings.IndexByte(scopeName, ':')
if scopeSplitIdx != -1 {
scopeName = evt.Name[0:scopeSplitIdx]
}
// Drop the connection or stream id
idSplitIdx := strings.IndexByte(scopeName, '-')
if idSplitIdx != -1 {
scopeName = scopeName[0:idSplitIdx]
}
if evt.DeltaIn != 0 {
*tags = (*tags)[:0]
*tags = append(*tags, "inbound", scopeName, resource)
blockedResources.WithLabelValues(*tags...).Add(float64(evt.DeltaIn))
}
if evt.DeltaOut != 0 {
*tags = (*tags)[:0]
*tags = append(*tags, "outbound", scopeName, resource)
blockedResources.WithLabelValues(*tags...).Add(float64(evt.DeltaOut))
}
if evt.Delta != 0 && resource == "connection" {
// This represents fds blocked
*tags = (*tags)[:0]
*tags = append(*tags, "", scopeName, "fd")
blockedResources.WithLabelValues(*tags...).Add(float64(evt.Delta))
} else if evt.Delta != 0 {
*tags = (*tags)[:0]
*tags = append(*tags, "", scopeName, resource)
blockedResources.WithLabelValues(*tags...).Add(float64(evt.Delta))
}
}
}

View File

@@ -0,0 +1,11 @@
//go:build !linux && !darwin && !windows
package rcmgr
import "runtime"
// TODO: figure out how to get the number of file descriptors on Windows and other systems
func getNumFDs() int {
log.Warnf("cannot determine number of file descriptors on %s", runtime.GOOS)
return 0
}

View File

@@ -0,0 +1,16 @@
//go:build linux || darwin
package rcmgr
import (
"golang.org/x/sys/unix"
)
func getNumFDs() int {
var l unix.Rlimit
if err := unix.Getrlimit(unix.RLIMIT_NOFILE, &l); err != nil {
log.Errorw("failed to get fd limit", "error", err)
return 0
}
return int(l.Cur)
}

View File

@@ -0,0 +1,11 @@
//go:build windows
package rcmgr
import (
"math"
)
func getNumFDs() int {
return math.MaxInt
}

View File

@@ -0,0 +1,698 @@
package rcmgr
import (
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/network"
)
type trace struct {
path string
ctx context.Context
cancel func()
wg sync.WaitGroup
mx sync.Mutex
done bool
pendingWrites []interface{}
reporters []TraceReporter
}
type TraceReporter interface {
// ConsumeEvent consumes a trace event. This is called synchronously,
// implementations should process the event quickly.
ConsumeEvent(TraceEvt)
}
func WithTrace(path string) Option {
return func(r *resourceManager) error {
if r.trace == nil {
r.trace = &trace{path: path}
} else {
r.trace.path = path
}
return nil
}
}
func WithTraceReporter(reporter TraceReporter) Option {
return func(r *resourceManager) error {
if r.trace == nil {
r.trace = &trace{}
}
r.trace.reporters = append(r.trace.reporters, reporter)
return nil
}
}
type TraceEvtTyp string
const (
TraceStartEvt TraceEvtTyp = "start"
TraceCreateScopeEvt TraceEvtTyp = "create_scope"
TraceDestroyScopeEvt TraceEvtTyp = "destroy_scope"
TraceReserveMemoryEvt TraceEvtTyp = "reserve_memory"
TraceBlockReserveMemoryEvt TraceEvtTyp = "block_reserve_memory"
TraceReleaseMemoryEvt TraceEvtTyp = "release_memory"
TraceAddStreamEvt TraceEvtTyp = "add_stream"
TraceBlockAddStreamEvt TraceEvtTyp = "block_add_stream"
TraceRemoveStreamEvt TraceEvtTyp = "remove_stream"
TraceAddConnEvt TraceEvtTyp = "add_conn"
TraceBlockAddConnEvt TraceEvtTyp = "block_add_conn"
TraceRemoveConnEvt TraceEvtTyp = "remove_conn"
)
type scopeClass struct {
name string
}
func (s scopeClass) MarshalJSON() ([]byte, error) {
name := s.name
var span string
if idx := strings.Index(name, "span:"); idx > -1 {
name = name[:idx-1]
span = name[idx+5:]
}
// System and Transient scope
if name == "system" || name == "transient" || name == "allowlistedSystem" || name == "allowlistedTransient" {
return json.Marshal(struct {
Class string
Span string `json:",omitempty"`
}{
Class: name,
Span: span,
})
}
// Connection scope
if strings.HasPrefix(name, "conn-") {
return json.Marshal(struct {
Class string
Conn string
Span string `json:",omitempty"`
}{
Class: "conn",
Conn: name[5:],
Span: span,
})
}
// Stream scope
if strings.HasPrefix(name, "stream-") {
return json.Marshal(struct {
Class string
Stream string
Span string `json:",omitempty"`
}{
Class: "stream",
Stream: name[7:],
Span: span,
})
}
// Peer scope
if strings.HasPrefix(name, "peer:") {
return json.Marshal(struct {
Class string
Peer string
Span string `json:",omitempty"`
}{
Class: "peer",
Peer: name[5:],
Span: span,
})
}
if strings.HasPrefix(name, "service:") {
if idx := strings.Index(name, "peer:"); idx > 0 { // Peer-Service scope
return json.Marshal(struct {
Class string
Service string
Peer string
Span string `json:",omitempty"`
}{
Class: "service-peer",
Service: name[8 : idx-1],
Peer: name[idx+5:],
Span: span,
})
} else { // Service scope
return json.Marshal(struct {
Class string
Service string
Span string `json:",omitempty"`
}{
Class: "service",
Service: name[8:],
Span: span,
})
}
}
if strings.HasPrefix(name, "protocol:") {
if idx := strings.Index(name, "peer:"); idx > -1 { // Peer-Protocol scope
return json.Marshal(struct {
Class string
Protocol string
Peer string
Span string `json:",omitempty"`
}{
Class: "protocol-peer",
Protocol: name[9 : idx-1],
Peer: name[idx+5:],
Span: span,
})
} else { // Protocol scope
return json.Marshal(struct {
Class string
Protocol string
Span string `json:",omitempty"`
}{
Class: "protocol",
Protocol: name[9:],
Span: span,
})
}
}
return nil, fmt.Errorf("unrecognized scope: %s", name)
}
type TraceEvt struct {
Time string
Type TraceEvtTyp
Scope *scopeClass `json:",omitempty"`
Name string `json:",omitempty"`
Limit interface{} `json:",omitempty"`
Priority uint8 `json:",omitempty"`
Delta int64 `json:",omitempty"`
DeltaIn int `json:",omitempty"`
DeltaOut int `json:",omitempty"`
Memory int64 `json:",omitempty"`
StreamsIn int `json:",omitempty"`
StreamsOut int `json:",omitempty"`
ConnsIn int `json:",omitempty"`
ConnsOut int `json:",omitempty"`
FD int `json:",omitempty"`
}
func (t *trace) push(evt TraceEvt) {
t.mx.Lock()
defer t.mx.Unlock()
if t.done {
return
}
evt.Time = time.Now().Format(time.RFC3339Nano)
if evt.Name != "" {
evt.Scope = &scopeClass{name: evt.Name}
}
for _, reporter := range t.reporters {
reporter.ConsumeEvent(evt)
}
if t.path != "" {
t.pendingWrites = append(t.pendingWrites, evt)
}
}
func (t *trace) backgroundWriter(out io.WriteCloser) {
defer t.wg.Done()
defer out.Close()
gzOut := gzip.NewWriter(out)
defer gzOut.Close()
jsonOut := json.NewEncoder(gzOut)
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
var pend []interface{}
getEvents := func() {
t.mx.Lock()
tmp := t.pendingWrites
t.pendingWrites = pend[:0]
pend = tmp
t.mx.Unlock()
}
for {
select {
case <-ticker.C:
getEvents()
if len(pend) == 0 {
continue
}
if err := t.writeEvents(pend, jsonOut); err != nil {
log.Warnf("error writing rcmgr trace: %s", err)
t.mx.Lock()
t.done = true
t.mx.Unlock()
return
}
if err := gzOut.Flush(); err != nil {
log.Warnf("error flushing rcmgr trace: %s", err)
t.mx.Lock()
t.done = true
t.mx.Unlock()
return
}
case <-t.ctx.Done():
getEvents()
if len(pend) == 0 {
return
}
if err := t.writeEvents(pend, jsonOut); err != nil {
log.Warnf("error writing rcmgr trace: %s", err)
return
}
if err := gzOut.Flush(); err != nil {
log.Warnf("error flushing rcmgr trace: %s", err)
}
return
}
}
}
func (t *trace) writeEvents(pend []interface{}, jout *json.Encoder) error {
for _, e := range pend {
if err := jout.Encode(e); err != nil {
return err
}
}
return nil
}
func (t *trace) Start(limits Limiter) error {
if t == nil {
return nil
}
t.ctx, t.cancel = context.WithCancel(context.Background())
if t.path != "" {
out, err := os.OpenFile(t.path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return nil
}
t.wg.Add(1)
go t.backgroundWriter(out)
}
t.push(TraceEvt{
Type: TraceStartEvt,
Limit: limits,
})
return nil
}
func (t *trace) Close() error {
if t == nil {
return nil
}
t.mx.Lock()
if t.done {
t.mx.Unlock()
return nil
}
t.cancel()
t.done = true
t.mx.Unlock()
t.wg.Wait()
return nil
}
func (t *trace) CreateScope(scope string, limit Limit) {
if t == nil {
return
}
t.push(TraceEvt{
Type: TraceCreateScopeEvt,
Name: scope,
Limit: limit,
})
}
func (t *trace) DestroyScope(scope string) {
if t == nil {
return
}
t.push(TraceEvt{
Type: TraceDestroyScopeEvt,
Name: scope,
})
}
func (t *trace) ReserveMemory(scope string, prio uint8, size, mem int64) {
if t == nil {
return
}
if size == 0 {
return
}
t.push(TraceEvt{
Type: TraceReserveMemoryEvt,
Name: scope,
Priority: prio,
Delta: size,
Memory: mem,
})
}
func (t *trace) BlockReserveMemory(scope string, prio uint8, size, mem int64) {
if t == nil {
return
}
if size == 0 {
return
}
t.push(TraceEvt{
Type: TraceBlockReserveMemoryEvt,
Name: scope,
Priority: prio,
Delta: size,
Memory: mem,
})
}
func (t *trace) ReleaseMemory(scope string, size, mem int64) {
if t == nil {
return
}
if size == 0 {
return
}
t.push(TraceEvt{
Type: TraceReleaseMemoryEvt,
Name: scope,
Delta: -size,
Memory: mem,
})
}
func (t *trace) AddStream(scope string, dir network.Direction, nstreamsIn, nstreamsOut int) {
if t == nil {
return
}
var deltaIn, deltaOut int
if dir == network.DirInbound {
deltaIn = 1
} else {
deltaOut = 1
}
t.push(TraceEvt{
Type: TraceAddStreamEvt,
Name: scope,
DeltaIn: deltaIn,
DeltaOut: deltaOut,
StreamsIn: nstreamsIn,
StreamsOut: nstreamsOut,
})
}
func (t *trace) BlockAddStream(scope string, dir network.Direction, nstreamsIn, nstreamsOut int) {
if t == nil {
return
}
var deltaIn, deltaOut int
if dir == network.DirInbound {
deltaIn = 1
} else {
deltaOut = 1
}
t.push(TraceEvt{
Type: TraceBlockAddStreamEvt,
Name: scope,
DeltaIn: deltaIn,
DeltaOut: deltaOut,
StreamsIn: nstreamsIn,
StreamsOut: nstreamsOut,
})
}
func (t *trace) RemoveStream(scope string, dir network.Direction, nstreamsIn, nstreamsOut int) {
if t == nil {
return
}
var deltaIn, deltaOut int
if dir == network.DirInbound {
deltaIn = -1
} else {
deltaOut = -1
}
t.push(TraceEvt{
Type: TraceRemoveStreamEvt,
Name: scope,
DeltaIn: deltaIn,
DeltaOut: deltaOut,
StreamsIn: nstreamsIn,
StreamsOut: nstreamsOut,
})
}
func (t *trace) AddStreams(scope string, deltaIn, deltaOut, nstreamsIn, nstreamsOut int) {
if t == nil {
return
}
if deltaIn == 0 && deltaOut == 0 {
return
}
t.push(TraceEvt{
Type: TraceAddStreamEvt,
Name: scope,
DeltaIn: deltaIn,
DeltaOut: deltaOut,
StreamsIn: nstreamsIn,
StreamsOut: nstreamsOut,
})
}
func (t *trace) BlockAddStreams(scope string, deltaIn, deltaOut, nstreamsIn, nstreamsOut int) {
if t == nil {
return
}
if deltaIn == 0 && deltaOut == 0 {
return
}
t.push(TraceEvt{
Type: TraceBlockAddStreamEvt,
Name: scope,
DeltaIn: deltaIn,
DeltaOut: deltaOut,
StreamsIn: nstreamsIn,
StreamsOut: nstreamsOut,
})
}
func (t *trace) RemoveStreams(scope string, deltaIn, deltaOut, nstreamsIn, nstreamsOut int) {
if t == nil {
return
}
if deltaIn == 0 && deltaOut == 0 {
return
}
t.push(TraceEvt{
Type: TraceRemoveStreamEvt,
Name: scope,
DeltaIn: -deltaIn,
DeltaOut: -deltaOut,
StreamsIn: nstreamsIn,
StreamsOut: nstreamsOut,
})
}
func (t *trace) AddConn(scope string, dir network.Direction, usefd bool, nconnsIn, nconnsOut, nfd int) {
if t == nil {
return
}
var deltaIn, deltaOut, deltafd int
if dir == network.DirInbound {
deltaIn = 1
} else {
deltaOut = 1
}
if usefd {
deltafd = 1
}
t.push(TraceEvt{
Type: TraceAddConnEvt,
Name: scope,
DeltaIn: deltaIn,
DeltaOut: deltaOut,
Delta: int64(deltafd),
ConnsIn: nconnsIn,
ConnsOut: nconnsOut,
FD: nfd,
})
}
func (t *trace) BlockAddConn(scope string, dir network.Direction, usefd bool, nconnsIn, nconnsOut, nfd int) {
if t == nil {
return
}
var deltaIn, deltaOut, deltafd int
if dir == network.DirInbound {
deltaIn = 1
} else {
deltaOut = 1
}
if usefd {
deltafd = 1
}
t.push(TraceEvt{
Type: TraceBlockAddConnEvt,
Name: scope,
DeltaIn: deltaIn,
DeltaOut: deltaOut,
Delta: int64(deltafd),
ConnsIn: nconnsIn,
ConnsOut: nconnsOut,
FD: nfd,
})
}
func (t *trace) RemoveConn(scope string, dir network.Direction, usefd bool, nconnsIn, nconnsOut, nfd int) {
if t == nil {
return
}
var deltaIn, deltaOut, deltafd int
if dir == network.DirInbound {
deltaIn = -1
} else {
deltaOut = -1
}
if usefd {
deltafd = -1
}
t.push(TraceEvt{
Type: TraceRemoveConnEvt,
Name: scope,
DeltaIn: deltaIn,
DeltaOut: deltaOut,
Delta: int64(deltafd),
ConnsIn: nconnsIn,
ConnsOut: nconnsOut,
FD: nfd,
})
}
func (t *trace) AddConns(scope string, deltaIn, deltaOut, deltafd, nconnsIn, nconnsOut, nfd int) {
if t == nil {
return
}
if deltaIn == 0 && deltaOut == 0 && deltafd == 0 {
return
}
t.push(TraceEvt{
Type: TraceAddConnEvt,
Name: scope,
DeltaIn: deltaIn,
DeltaOut: deltaOut,
Delta: int64(deltafd),
ConnsIn: nconnsIn,
ConnsOut: nconnsOut,
FD: nfd,
})
}
func (t *trace) BlockAddConns(scope string, deltaIn, deltaOut, deltafd, nconnsIn, nconnsOut, nfd int) {
if t == nil {
return
}
if deltaIn == 0 && deltaOut == 0 && deltafd == 0 {
return
}
t.push(TraceEvt{
Type: TraceBlockAddConnEvt,
Name: scope,
DeltaIn: deltaIn,
DeltaOut: deltaOut,
Delta: int64(deltafd),
ConnsIn: nconnsIn,
ConnsOut: nconnsOut,
FD: nfd,
})
}
func (t *trace) RemoveConns(scope string, deltaIn, deltaOut, deltafd, nconnsIn, nconnsOut, nfd int) {
if t == nil {
return
}
if deltaIn == 0 && deltaOut == 0 && deltafd == 0 {
return
}
t.push(TraceEvt{
Type: TraceRemoveConnEvt,
Name: scope,
DeltaIn: -deltaIn,
DeltaOut: -deltaOut,
Delta: -int64(deltafd),
ConnsIn: nconnsIn,
ConnsOut: nconnsOut,
FD: nfd,
})
}

View File

@@ -0,0 +1,222 @@
package routedhost
import (
"context"
"fmt"
"time"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
)
var log = logging.Logger("routedhost")
// AddressTTL is the expiry time for our addresses.
// We expire them quickly.
const AddressTTL = time.Second * 10
// RoutedHost is a p2p Host that includes a routing system.
// This allows the Host to find the addresses for peers when
// it does not have them.
type RoutedHost struct {
host host.Host // embedded other host.
route Routing
}
type Routing interface {
FindPeer(context.Context, peer.ID) (peer.AddrInfo, error)
}
func Wrap(h host.Host, r Routing) *RoutedHost {
return &RoutedHost{h, r}
}
// Connect ensures there is a connection between this host and the peer with
// given peer.ID. See (host.Host).Connect for more information.
//
// RoutedHost's Connect differs in that if the host has no addresses for a
// given peer, it will use its routing system to try to find some.
func (rh *RoutedHost) Connect(ctx context.Context, pi peer.AddrInfo) error {
// first, check if we're already connected unless force direct dial.
forceDirect, _ := network.GetForceDirectDial(ctx)
if !forceDirect {
if rh.Network().Connectedness(pi.ID) == network.Connected {
return nil
}
}
// if we were given some addresses, keep + use them.
if len(pi.Addrs) > 0 {
rh.Peerstore().AddAddrs(pi.ID, pi.Addrs, peerstore.TempAddrTTL)
}
// Check if we have some addresses in our recent memory.
addrs := rh.Peerstore().Addrs(pi.ID)
if len(addrs) < 1 {
// no addrs? find some with the routing system.
var err error
addrs, err = rh.findPeerAddrs(ctx, pi.ID)
if err != nil {
return err
}
}
// Issue 448: if our address set includes routed specific relay addrs,
// we need to make sure the relay's addr itself is in the peerstore or else
// we won't be able to dial it.
for _, addr := range addrs {
if _, err := addr.ValueForProtocol(ma.P_CIRCUIT); err != nil {
// not a relay address
continue
}
if addr.Protocols()[0].Code != ma.P_P2P {
// not a routed relay specific address
continue
}
relay, _ := addr.ValueForProtocol(ma.P_P2P)
relayID, err := peer.Decode(relay)
if err != nil {
log.Debugf("failed to parse relay ID in address %s: %s", relay, err)
continue
}
if len(rh.Peerstore().Addrs(relayID)) > 0 {
// we already have addrs for this relay
continue
}
relayAddrs, err := rh.findPeerAddrs(ctx, relayID)
if err != nil {
log.Debugf("failed to find relay %s: %s", relay, err)
continue
}
rh.Peerstore().AddAddrs(relayID, relayAddrs, peerstore.TempAddrTTL)
}
// if we're here, we got some addrs. let's use our wrapped host to connect.
pi.Addrs = addrs
if cerr := rh.host.Connect(ctx, pi); cerr != nil {
// We couldn't connect. Let's check if we have the most
// up-to-date addresses for the given peer. If there
// are addresses we didn't know about previously, we
// try to connect again.
newAddrs, err := rh.findPeerAddrs(ctx, pi.ID)
if err != nil {
log.Debugf("failed to find more peer addresses %s: %s", pi.ID, err)
return cerr
}
// Build lookup map
lookup := make(map[string]struct{}, len(addrs))
for _, addr := range addrs {
lookup[string(addr.Bytes())] = struct{}{}
}
// if there's any address that's not in the previous set
// of addresses, try to connect again. If all addresses
// where known previously we return the original error.
for _, newAddr := range newAddrs {
if _, found := lookup[string(newAddr.Bytes())]; found {
continue
}
pi.Addrs = newAddrs
return rh.host.Connect(ctx, pi)
}
// No appropriate new address found.
// Return the original dial error.
return cerr
}
return nil
}
func (rh *RoutedHost) findPeerAddrs(ctx context.Context, id peer.ID) ([]ma.Multiaddr, error) {
pi, err := rh.route.FindPeer(ctx, id)
if err != nil {
return nil, err // couldnt find any :(
}
if pi.ID != id {
err = fmt.Errorf("routing failure: provided addrs for different peer")
log.Errorw("got wrong peer",
"error", err,
"wantedPeer", id,
"gotPeer", pi.ID,
)
return nil, err
}
return pi.Addrs, nil
}
func (rh *RoutedHost) ID() peer.ID {
return rh.host.ID()
}
func (rh *RoutedHost) Peerstore() peerstore.Peerstore {
return rh.host.Peerstore()
}
func (rh *RoutedHost) Addrs() []ma.Multiaddr {
return rh.host.Addrs()
}
func (rh *RoutedHost) Network() network.Network {
return rh.host.Network()
}
func (rh *RoutedHost) Mux() protocol.Switch {
return rh.host.Mux()
}
func (rh *RoutedHost) EventBus() event.Bus {
return rh.host.EventBus()
}
func (rh *RoutedHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) {
rh.host.SetStreamHandler(pid, handler)
}
func (rh *RoutedHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) bool, handler network.StreamHandler) {
rh.host.SetStreamHandlerMatch(pid, m, handler)
}
func (rh *RoutedHost) RemoveStreamHandler(pid protocol.ID) {
rh.host.RemoveStreamHandler(pid)
}
func (rh *RoutedHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) {
// Ensure we have a connection, with peer addresses resolved by the routing system (#207)
// It is not sufficient to let the underlying host connect, it will most likely not have
// any addresses for the peer without any prior connections.
// If the caller wants to prevent the host from dialing, it should use the NoDial option.
if nodial, _ := network.GetNoDial(ctx); !nodial {
err := rh.Connect(ctx, peer.AddrInfo{ID: p})
if err != nil {
return nil, err
}
}
return rh.host.NewStream(ctx, p, pids...)
}
func (rh *RoutedHost) Close() error {
// no need to close IpfsRouting. we dont own it.
return rh.host.Close()
}
func (rh *RoutedHost) ConnManager() connmgr.ConnManager {
return rh.host.ConnManager()
}
var _ (host.Host) = (*RoutedHost)(nil)

View File

@@ -0,0 +1,29 @@
package metricshelper
import ma "github.com/multiformats/go-multiaddr"
var transports = [...]int{ma.P_CIRCUIT, ma.P_WEBRTC, ma.P_WEBTRANSPORT, ma.P_QUIC, ma.P_QUIC_V1, ma.P_WSS, ma.P_WS, ma.P_TCP}
func GetTransport(a ma.Multiaddr) string {
for _, t := range transports {
if _, err := a.ValueForProtocol(t); err == nil {
return ma.ProtocolWithCode(t).Name
}
}
return "other"
}
func GetIPVersion(addr ma.Multiaddr) string {
version := "unknown"
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_IP4 {
version = "ip4"
return false
} else if c.Protocol().Code == ma.P_IP6 {
version = "ip6"
return false
}
return true
})
return version
}

View File

@@ -0,0 +1,14 @@
package metricshelper
import "github.com/libp2p/go-libp2p/core/network"
func GetDirection(dir network.Direction) string {
switch dir {
case network.DirOutbound:
return "outbound"
case network.DirInbound:
return "inbound"
default:
return "unknown"
}
}

View File

@@ -0,0 +1,26 @@
package metricshelper
import (
"fmt"
"sync"
)
const capacity = 8
var stringPool = sync.Pool{New: func() any {
s := make([]string, 0, capacity)
return &s
}}
func GetStringSlice() *[]string {
s := stringPool.Get().(*[]string)
*s = (*s)[:0]
return s
}
func PutStringSlice(s *[]string) {
if c := cap(*s); c < capacity {
panic(fmt.Sprintf("expected a string slice with capacity 8 or greater, got %d", c))
}
stringPool.Put(s)
}

View File

@@ -0,0 +1,20 @@
package metricshelper
import (
"errors"
"github.com/prometheus/client_golang/prometheus"
)
// RegisterCollectors registers the collectors with reg ignoring
// reregistration error and panics on any other error
func RegisterCollectors(reg prometheus.Registerer, collectors ...prometheus.Collector) {
for _, c := range collectors {
err := reg.Register(c)
if err != nil {
if ok := errors.As(err, &prometheus.AlreadyRegisteredError{}); !ok {
panic(err)
}
}
}
}

View File

@@ -0,0 +1,48 @@
package mplex
import (
"context"
"github.com/libp2p/go-libp2p/core/network"
mp "github.com/libp2p/go-mplex"
)
type conn mp.Multiplex
var _ network.MuxedConn = &conn{}
// NewMuxedConn constructs a new Conn from a *mp.Multiplex.
func NewMuxedConn(m *mp.Multiplex) network.MuxedConn {
return (*conn)(m)
}
func (c *conn) Close() error {
return c.mplex().Close()
}
func (c *conn) IsClosed() bool {
return c.mplex().IsClosed()
}
// OpenStream creates a new stream.
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
s, err := c.mplex().NewStream(ctx)
if err != nil {
return nil, err
}
return (*stream)(s), nil
}
// AcceptStream accepts a stream opened by the other side.
func (c *conn) AcceptStream() (network.MuxedStream, error) {
s, err := c.mplex().Accept()
if err != nil {
return nil, err
}
return (*stream)(s), nil
}
func (c *conn) mplex() *mp.Multiplex {
return (*mp.Multiplex)(c)
}

View File

@@ -0,0 +1,64 @@
package mplex
import (
"time"
"github.com/libp2p/go-libp2p/core/network"
mp "github.com/libp2p/go-mplex"
)
// stream implements network.MuxedStream over mplex.Stream.
type stream mp.Stream
var _ network.MuxedStream = &stream{}
func (s *stream) Read(b []byte) (n int, err error) {
n, err = s.mplex().Read(b)
if err == mp.ErrStreamReset {
err = network.ErrReset
}
return n, err
}
func (s *stream) Write(b []byte) (n int, err error) {
n, err = s.mplex().Write(b)
if err == mp.ErrStreamReset {
err = network.ErrReset
}
return n, err
}
func (s *stream) Close() error {
return s.mplex().Close()
}
func (s *stream) CloseWrite() error {
return s.mplex().CloseWrite()
}
func (s *stream) CloseRead() error {
return s.mplex().CloseRead()
}
func (s *stream) Reset() error {
return s.mplex().Reset()
}
func (s *stream) SetDeadline(t time.Time) error {
return s.mplex().SetDeadline(t)
}
func (s *stream) SetReadDeadline(t time.Time) error {
return s.mplex().SetReadDeadline(t)
}
func (s *stream) SetWriteDeadline(t time.Time) error {
return s.mplex().SetWriteDeadline(t)
}
func (s *stream) mplex() *mp.Stream {
return (*mp.Stream)(s)
}

View File

@@ -0,0 +1,28 @@
package mplex
import (
"net"
"github.com/libp2p/go-libp2p/core/network"
mp "github.com/libp2p/go-mplex"
)
// DefaultTransport has default settings for Transport
var DefaultTransport = &Transport{}
const ID = "/mplex/6.7.0"
var _ network.Multiplexer = &Transport{}
// Transport implements mux.Multiplexer that constructs
// mplex-backed muxed connections.
type Transport struct{}
func (t *Transport) NewConn(nc net.Conn, isServer bool, scope network.PeerScope) (network.MuxedConn, error) {
m, err := mp.NewMultiplex(nc, isServer, scope)
if err != nil {
return nil, err
}
return NewMuxedConn(m), nil
}

View File

@@ -0,0 +1,49 @@
package yamux
import (
"context"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-yamux/v4"
)
// conn implements mux.MuxedConn over yamux.Session.
type conn yamux.Session
var _ network.MuxedConn = &conn{}
// NewMuxedConn constructs a new MuxedConn from a yamux.Session.
func NewMuxedConn(m *yamux.Session) network.MuxedConn {
return (*conn)(m)
}
// Close closes underlying yamux
func (c *conn) Close() error {
return c.yamux().Close()
}
// IsClosed checks if yamux.Session is in closed state.
func (c *conn) IsClosed() bool {
return c.yamux().IsClosed()
}
// OpenStream creates a new stream.
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
s, err := c.yamux().OpenStream(ctx)
if err != nil {
return nil, err
}
return (*stream)(s), nil
}
// AcceptStream accepts a stream opened by the other side.
func (c *conn) AcceptStream() (network.MuxedStream, error) {
s, err := c.yamux().AcceptStream()
return (*stream)(s), err
}
func (c *conn) yamux() *yamux.Session {
return (*yamux.Session)(c)
}

View File

@@ -0,0 +1,64 @@
package yamux
import (
"time"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-yamux/v4"
)
// stream implements mux.MuxedStream over yamux.Stream.
type stream yamux.Stream
var _ network.MuxedStream = &stream{}
func (s *stream) Read(b []byte) (n int, err error) {
n, err = s.yamux().Read(b)
if err == yamux.ErrStreamReset {
err = network.ErrReset
}
return n, err
}
func (s *stream) Write(b []byte) (n int, err error) {
n, err = s.yamux().Write(b)
if err == yamux.ErrStreamReset {
err = network.ErrReset
}
return n, err
}
func (s *stream) Close() error {
return s.yamux().Close()
}
func (s *stream) Reset() error {
return s.yamux().Reset()
}
func (s *stream) CloseRead() error {
return s.yamux().CloseRead()
}
func (s *stream) CloseWrite() error {
return s.yamux().CloseWrite()
}
func (s *stream) SetDeadline(t time.Time) error {
return s.yamux().SetDeadline(t)
}
func (s *stream) SetReadDeadline(t time.Time) error {
return s.yamux().SetReadDeadline(t)
}
func (s *stream) SetWriteDeadline(t time.Time) error {
return s.yamux().SetWriteDeadline(t)
}
func (s *stream) yamux() *yamux.Stream {
return (*yamux.Stream)(s)
}

View File

@@ -0,0 +1,63 @@
package yamux
import (
"io"
"math"
"net"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-yamux/v4"
)
var DefaultTransport *Transport
const ID = "/yamux/1.0.0"
func init() {
config := yamux.DefaultConfig()
// We've bumped this to 16MiB as this critically limits throughput.
//
// 1MiB means a best case of 10MiB/s (83.89Mbps) on a connection with
// 100ms latency. The default gave us 2.4MiB *best case* which was
// totally unacceptable.
config.MaxStreamWindowSize = uint32(16 * 1024 * 1024)
// don't spam
config.LogOutput = io.Discard
// We always run over a security transport that buffers internally
// (i.e., uses a block cipher).
config.ReadBufSize = 0
// Effectively disable the incoming streams limit.
// This is now dynamically limited by the resource manager.
config.MaxIncomingStreams = math.MaxUint32
DefaultTransport = (*Transport)(config)
}
// Transport implements mux.Multiplexer that constructs
// yamux-backed muxed connections.
type Transport yamux.Config
var _ network.Multiplexer = &Transport{}
func (t *Transport) NewConn(nc net.Conn, isServer bool, scope network.PeerScope) (network.MuxedConn, error) {
var newSpan func() (yamux.MemoryManager, error)
if scope != nil {
newSpan = func() (yamux.MemoryManager, error) { return scope.BeginSpan() }
}
var s *yamux.Session
var err error
if isServer {
s, err = yamux.Server(nc, t.Config(), newSpan)
} else {
s, err = yamux.Client(nc, t.Config(), newSpan)
}
if err != nil {
return nil, err
}
return NewMuxedConn(s), nil
}
func (t *Transport) Config() *yamux.Config {
return (*yamux.Config)(t)
}

View File

@@ -0,0 +1,725 @@
package connmgr
import (
"context"
"sort"
"sync"
"sync/atomic"
"time"
"github.com/benbjohnson/clock"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
)
var log = logging.Logger("connmgr")
// BasicConnMgr is a ConnManager that trims connections whenever the count exceeds the
// high watermark. New connections are given a grace period before they're subject
// to trimming. Trims are automatically run on demand, only if the time from the
// previous trim is higher than 10 seconds. Furthermore, trims can be explicitly
// requested through the public interface of this struct (see TrimOpenConns).
//
// See configuration parameters in NewConnManager.
type BasicConnMgr struct {
*decayer
clock clock.Clock
cfg *config
segments segments
plk sync.RWMutex
protected map[peer.ID]map[string]struct{}
// channel-based semaphore that enforces only a single trim is in progress
trimMutex sync.Mutex
connCount atomic.Int32
// to be accessed atomically. This is mimicking the implementation of a sync.Once.
// Take care of correct alignment when modifying this struct.
trimCount uint64
lastTrimMu sync.RWMutex
lastTrim time.Time
refCount sync.WaitGroup
ctx context.Context
cancel func()
unregisterMemoryWatcher func()
}
var (
_ connmgr.ConnManager = (*BasicConnMgr)(nil)
_ connmgr.Decayer = (*BasicConnMgr)(nil)
)
type segment struct {
sync.Mutex
peers map[peer.ID]*peerInfo
}
type segments struct {
// bucketsMu is used to prevent deadlocks when concurrent processes try to
// grab multiple segment locks at once. If you need multiple segment locks
// at once, you should grab this lock first. You may release this lock once
// you have the segment locks.
bucketsMu sync.Mutex
buckets [256]*segment
}
func (ss *segments) get(p peer.ID) *segment {
return ss.buckets[byte(p[len(p)-1])]
}
func (ss *segments) countPeers() (count int) {
for _, seg := range ss.buckets {
seg.Lock()
count += len(seg.peers)
seg.Unlock()
}
return count
}
func (s *segment) tagInfoFor(p peer.ID, now time.Time) *peerInfo {
pi, ok := s.peers[p]
if ok {
return pi
}
// create a temporary peer to buffer early tags before the Connected notification arrives.
pi = &peerInfo{
id: p,
firstSeen: now, // this timestamp will be updated when the first Connected notification arrives.
temp: true,
tags: make(map[string]int),
decaying: make(map[*decayingTag]*connmgr.DecayingValue),
conns: make(map[network.Conn]time.Time),
}
s.peers[p] = pi
return pi
}
// NewConnManager creates a new BasicConnMgr with the provided params:
// lo and hi are watermarks governing the number of connections that'll be maintained.
// When the peer count exceeds the 'high watermark', as many peers will be pruned (and
// their connections terminated) until 'low watermark' peers remain.
func NewConnManager(low, hi int, opts ...Option) (*BasicConnMgr, error) {
cfg := &config{
highWater: hi,
lowWater: low,
gracePeriod: time.Minute,
silencePeriod: 10 * time.Second,
clock: clock.New(),
}
for _, o := range opts {
if err := o(cfg); err != nil {
return nil, err
}
}
if cfg.decayer == nil {
// Set the default decayer config.
cfg.decayer = (&DecayerCfg{}).WithDefaults()
}
cm := &BasicConnMgr{
cfg: cfg,
clock: cfg.clock,
protected: make(map[peer.ID]map[string]struct{}, 16),
segments: segments{},
}
for i := range cm.segments.buckets {
cm.segments.buckets[i] = &segment{
peers: make(map[peer.ID]*peerInfo),
}
}
cm.ctx, cm.cancel = context.WithCancel(context.Background())
if cfg.emergencyTrim {
// When we're running low on memory, immediately trigger a trim.
cm.unregisterMemoryWatcher = registerWatchdog(cm.memoryEmergency)
}
decay, _ := NewDecayer(cfg.decayer, cm)
cm.decayer = decay
cm.refCount.Add(1)
go cm.background()
return cm, nil
}
// memoryEmergency is run when we run low on memory.
// Close connections until we right the low watermark.
// We don't pay attention to the silence period or the grace period.
// We try to not kill protected connections, but if that turns out to be necessary, not connection is safe!
func (cm *BasicConnMgr) memoryEmergency() {
connCount := int(cm.connCount.Load())
target := connCount - cm.cfg.lowWater
if target < 0 {
log.Warnw("Low on memory, but we only have a few connections", "num", connCount, "low watermark", cm.cfg.lowWater)
return
} else {
log.Warnf("Low on memory. Closing %d connections.", target)
}
cm.trimMutex.Lock()
defer atomic.AddUint64(&cm.trimCount, 1)
defer cm.trimMutex.Unlock()
// Trim connections without paying attention to the silence period.
for _, c := range cm.getConnsToCloseEmergency(target) {
log.Infow("low on memory. closing conn", "peer", c.RemotePeer())
c.Close()
}
// finally, update the last trim time.
cm.lastTrimMu.Lock()
cm.lastTrim = cm.clock.Now()
cm.lastTrimMu.Unlock()
}
func (cm *BasicConnMgr) Close() error {
cm.cancel()
if cm.unregisterMemoryWatcher != nil {
cm.unregisterMemoryWatcher()
}
if err := cm.decayer.Close(); err != nil {
return err
}
cm.refCount.Wait()
return nil
}
func (cm *BasicConnMgr) Protect(id peer.ID, tag string) {
cm.plk.Lock()
defer cm.plk.Unlock()
tags, ok := cm.protected[id]
if !ok {
tags = make(map[string]struct{}, 2)
cm.protected[id] = tags
}
tags[tag] = struct{}{}
}
func (cm *BasicConnMgr) Unprotect(id peer.ID, tag string) (protected bool) {
cm.plk.Lock()
defer cm.plk.Unlock()
tags, ok := cm.protected[id]
if !ok {
return false
}
if delete(tags, tag); len(tags) == 0 {
delete(cm.protected, id)
return false
}
return true
}
func (cm *BasicConnMgr) IsProtected(id peer.ID, tag string) (protected bool) {
cm.plk.Lock()
defer cm.plk.Unlock()
tags, ok := cm.protected[id]
if !ok {
return false
}
if tag == "" {
return true
}
_, protected = tags[tag]
return protected
}
// peerInfo stores metadata for a given peer.
type peerInfo struct {
id peer.ID
tags map[string]int // value for each tag
decaying map[*decayingTag]*connmgr.DecayingValue // decaying tags
value int // cached sum of all tag values
temp bool // this is a temporary entry holding early tags, and awaiting connections
conns map[network.Conn]time.Time // start time of each connection
firstSeen time.Time // timestamp when we began tracking this peer.
}
type peerInfos []*peerInfo
// SortByValueAndStreams sorts peerInfos by their value and stream count. It
// will sort peers with no streams before those with streams (all else being
// equal). If `sortByMoreStreams` is true it will sort peers with more streams
// before those with fewer streams. This is useful to prioritize freeing memory.
func (p peerInfos) SortByValueAndStreams(segments *segments, sortByMoreStreams bool) {
sort.Slice(p, func(i, j int) bool {
left, right := p[i], p[j]
// Grab this lock so that we can grab both segment locks below without deadlocking.
segments.bucketsMu.Lock()
// lock this to protect from concurrent modifications from connect/disconnect events
leftSegment := segments.get(left.id)
leftSegment.Lock()
defer leftSegment.Unlock()
rightSegment := segments.get(right.id)
if leftSegment != rightSegment {
// These two peers are not in the same segment, lets get the lock
rightSegment.Lock()
defer rightSegment.Unlock()
}
segments.bucketsMu.Unlock()
// temporary peers are preferred for pruning.
if left.temp != right.temp {
return left.temp
}
// otherwise, compare by value.
if left.value != right.value {
return left.value < right.value
}
incomingAndStreams := func(m map[network.Conn]time.Time) (incoming bool, numStreams int) {
for c := range m {
stat := c.Stat()
if stat.Direction == network.DirInbound {
incoming = true
}
numStreams += stat.NumStreams
}
return
}
leftIncoming, leftStreams := incomingAndStreams(left.conns)
rightIncoming, rightStreams := incomingAndStreams(right.conns)
// prefer closing inactive connections (no streams open)
if rightStreams != leftStreams && (leftStreams == 0 || rightStreams == 0) {
return leftStreams < rightStreams
}
// incoming connections are preferred for pruning
if leftIncoming != rightIncoming {
return leftIncoming
}
if sortByMoreStreams {
// prune connections with a higher number of streams first
return rightStreams < leftStreams
} else {
return leftStreams < rightStreams
}
})
}
// TrimOpenConns closes the connections of as many peers as needed to make the peer count
// equal the low watermark. Peers are sorted in ascending order based on their total value,
// pruning those peers with the lowest scores first, as long as they are not within their
// grace period.
//
// This function blocks until a trim is completed. If a trim is underway, a new
// one won't be started, and instead it'll wait until that one is completed before
// returning.
func (cm *BasicConnMgr) TrimOpenConns(_ context.Context) {
// TODO: error return value so we can cleanly signal we are aborting because:
// (a) there's another trim in progress, or (b) the silence period is in effect.
cm.doTrim()
}
func (cm *BasicConnMgr) background() {
defer cm.refCount.Done()
interval := cm.cfg.gracePeriod / 2
if cm.cfg.silencePeriod != 0 {
interval = cm.cfg.silencePeriod
}
ticker := cm.clock.Ticker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if cm.connCount.Load() < int32(cm.cfg.highWater) {
// Below high water, skip.
continue
}
case <-cm.ctx.Done():
return
}
cm.trim()
}
}
func (cm *BasicConnMgr) doTrim() {
// This logic is mimicking the implementation of sync.Once in the standard library.
count := atomic.LoadUint64(&cm.trimCount)
cm.trimMutex.Lock()
defer cm.trimMutex.Unlock()
if count == atomic.LoadUint64(&cm.trimCount) {
cm.trim()
cm.lastTrimMu.Lock()
cm.lastTrim = cm.clock.Now()
cm.lastTrimMu.Unlock()
atomic.AddUint64(&cm.trimCount, 1)
}
}
// trim starts the trim, if the last trim happened before the configured silence period.
func (cm *BasicConnMgr) trim() {
// do the actual trim.
for _, c := range cm.getConnsToClose() {
log.Debugw("closing conn", "peer", c.RemotePeer())
c.Close()
}
}
func (cm *BasicConnMgr) getConnsToCloseEmergency(target int) []network.Conn {
candidates := make(peerInfos, 0, cm.segments.countPeers())
cm.plk.RLock()
for _, s := range cm.segments.buckets {
s.Lock()
for id, inf := range s.peers {
if _, ok := cm.protected[id]; ok {
// skip over protected peer.
continue
}
candidates = append(candidates, inf)
}
s.Unlock()
}
cm.plk.RUnlock()
// Sort peers according to their value.
candidates.SortByValueAndStreams(&cm.segments, true)
selected := make([]network.Conn, 0, target+10)
for _, inf := range candidates {
if target <= 0 {
break
}
s := cm.segments.get(inf.id)
s.Lock()
for c := range inf.conns {
selected = append(selected, c)
}
target -= len(inf.conns)
s.Unlock()
}
if len(selected) >= target {
// We found enough connections that were not protected.
return selected
}
// We didn't find enough unprotected connections.
// We have no choice but to kill some protected connections.
candidates = candidates[:0]
cm.plk.RLock()
for _, s := range cm.segments.buckets {
s.Lock()
for _, inf := range s.peers {
candidates = append(candidates, inf)
}
s.Unlock()
}
cm.plk.RUnlock()
candidates.SortByValueAndStreams(&cm.segments, true)
for _, inf := range candidates {
if target <= 0 {
break
}
// lock this to protect from concurrent modifications from connect/disconnect events
s := cm.segments.get(inf.id)
s.Lock()
for c := range inf.conns {
selected = append(selected, c)
}
target -= len(inf.conns)
s.Unlock()
}
return selected
}
// getConnsToClose runs the heuristics described in TrimOpenConns and returns the
// connections to close.
func (cm *BasicConnMgr) getConnsToClose() []network.Conn {
if cm.cfg.lowWater == 0 || cm.cfg.highWater == 0 {
// disabled
return nil
}
if int(cm.connCount.Load()) <= cm.cfg.lowWater {
log.Info("open connection count below limit")
return nil
}
candidates := make(peerInfos, 0, cm.segments.countPeers())
var ncandidates int
gracePeriodStart := cm.clock.Now().Add(-cm.cfg.gracePeriod)
cm.plk.RLock()
for _, s := range cm.segments.buckets {
s.Lock()
for id, inf := range s.peers {
if _, ok := cm.protected[id]; ok {
// skip over protected peer.
continue
}
if inf.firstSeen.After(gracePeriodStart) {
// skip peers in the grace period.
continue
}
// note that we're copying the entry here,
// but since inf.conns is a map, it will still point to the original object
candidates = append(candidates, inf)
ncandidates += len(inf.conns)
}
s.Unlock()
}
cm.plk.RUnlock()
if ncandidates < cm.cfg.lowWater {
log.Info("open connection count above limit but too many are in the grace period")
// We have too many connections but fewer than lowWater
// connections out of the grace period.
//
// If we trimmed now, we'd kill potentially useful connections.
return nil
}
// Sort peers according to their value.
candidates.SortByValueAndStreams(&cm.segments, false)
target := ncandidates - cm.cfg.lowWater
// slightly overallocate because we may have more than one conns per peer
selected := make([]network.Conn, 0, target+10)
for _, inf := range candidates {
if target <= 0 {
break
}
// lock this to protect from concurrent modifications from connect/disconnect events
s := cm.segments.get(inf.id)
s.Lock()
if len(inf.conns) == 0 && inf.temp {
// handle temporary entries for early tags -- this entry has gone past the grace period
// and still holds no connections, so prune it.
delete(s.peers, inf.id)
} else {
for c := range inf.conns {
selected = append(selected, c)
}
target -= len(inf.conns)
}
s.Unlock()
}
return selected
}
// GetTagInfo is called to fetch the tag information associated with a given
// peer, nil is returned if p refers to an unknown peer.
func (cm *BasicConnMgr) GetTagInfo(p peer.ID) *connmgr.TagInfo {
s := cm.segments.get(p)
s.Lock()
defer s.Unlock()
pi, ok := s.peers[p]
if !ok {
return nil
}
out := &connmgr.TagInfo{
FirstSeen: pi.firstSeen,
Value: pi.value,
Tags: make(map[string]int),
Conns: make(map[string]time.Time),
}
for t, v := range pi.tags {
out.Tags[t] = v
}
for t, v := range pi.decaying {
out.Tags[t.name] = v.Value
}
for c, t := range pi.conns {
out.Conns[c.RemoteMultiaddr().String()] = t
}
return out
}
// TagPeer is called to associate a string and integer with a given peer.
func (cm *BasicConnMgr) TagPeer(p peer.ID, tag string, val int) {
s := cm.segments.get(p)
s.Lock()
defer s.Unlock()
pi := s.tagInfoFor(p, cm.clock.Now())
// Update the total value of the peer.
pi.value += val - pi.tags[tag]
pi.tags[tag] = val
}
// UntagPeer is called to disassociate a string and integer from a given peer.
func (cm *BasicConnMgr) UntagPeer(p peer.ID, tag string) {
s := cm.segments.get(p)
s.Lock()
defer s.Unlock()
pi, ok := s.peers[p]
if !ok {
log.Info("tried to remove tag from untracked peer: ", p)
return
}
// Update the total value of the peer.
pi.value -= pi.tags[tag]
delete(pi.tags, tag)
}
// UpsertTag is called to insert/update a peer tag
func (cm *BasicConnMgr) UpsertTag(p peer.ID, tag string, upsert func(int) int) {
s := cm.segments.get(p)
s.Lock()
defer s.Unlock()
pi := s.tagInfoFor(p, cm.clock.Now())
oldval := pi.tags[tag]
newval := upsert(oldval)
pi.value += newval - oldval
pi.tags[tag] = newval
}
// CMInfo holds the configuration for BasicConnMgr, as well as status data.
type CMInfo struct {
// The low watermark, as described in NewConnManager.
LowWater int
// The high watermark, as described in NewConnManager.
HighWater int
// The timestamp when the last trim was triggered.
LastTrim time.Time
// The configured grace period, as described in NewConnManager.
GracePeriod time.Duration
// The current connection count.
ConnCount int
}
// GetInfo returns the configuration and status data for this connection manager.
func (cm *BasicConnMgr) GetInfo() CMInfo {
cm.lastTrimMu.RLock()
lastTrim := cm.lastTrim
cm.lastTrimMu.RUnlock()
return CMInfo{
HighWater: cm.cfg.highWater,
LowWater: cm.cfg.lowWater,
LastTrim: lastTrim,
GracePeriod: cm.cfg.gracePeriod,
ConnCount: int(cm.connCount.Load()),
}
}
// Notifee returns a sink through which Notifiers can inform the BasicConnMgr when
// events occur. Currently, the notifee only reacts upon connection events
// {Connected, Disconnected}.
func (cm *BasicConnMgr) Notifee() network.Notifiee {
return (*cmNotifee)(cm)
}
type cmNotifee BasicConnMgr
func (nn *cmNotifee) cm() *BasicConnMgr {
return (*BasicConnMgr)(nn)
}
// Connected is called by notifiers to inform that a new connection has been established.
// The notifee updates the BasicConnMgr to start tracking the connection. If the new connection
// count exceeds the high watermark, a trim may be triggered.
func (nn *cmNotifee) Connected(n network.Network, c network.Conn) {
cm := nn.cm()
p := c.RemotePeer()
s := cm.segments.get(p)
s.Lock()
defer s.Unlock()
id := c.RemotePeer()
pinfo, ok := s.peers[id]
if !ok {
pinfo = &peerInfo{
id: id,
firstSeen: cm.clock.Now(),
tags: make(map[string]int),
decaying: make(map[*decayingTag]*connmgr.DecayingValue),
conns: make(map[network.Conn]time.Time),
}
s.peers[id] = pinfo
} else if pinfo.temp {
// we had created a temporary entry for this peer to buffer early tags before the
// Connected notification arrived: flip the temporary flag, and update the firstSeen
// timestamp to the real one.
pinfo.temp = false
pinfo.firstSeen = cm.clock.Now()
}
_, ok = pinfo.conns[c]
if ok {
log.Error("received connected notification for conn we are already tracking: ", p)
return
}
pinfo.conns[c] = cm.clock.Now()
cm.connCount.Add(1)
}
// Disconnected is called by notifiers to inform that an existing connection has been closed or terminated.
// The notifee updates the BasicConnMgr accordingly to stop tracking the connection, and performs housekeeping.
func (nn *cmNotifee) Disconnected(n network.Network, c network.Conn) {
cm := nn.cm()
p := c.RemotePeer()
s := cm.segments.get(p)
s.Lock()
defer s.Unlock()
cinf, ok := s.peers[p]
if !ok {
log.Error("received disconnected notification for peer we are not tracking: ", p)
return
}
_, ok = cinf.conns[c]
if !ok {
log.Error("received disconnected notification for conn we are not tracking: ", p)
return
}
delete(cinf.conns, c)
if len(cinf.conns) == 0 {
delete(s.peers, p)
}
cm.connCount.Add(-1)
}
// Listen is no-op in this implementation.
func (nn *cmNotifee) Listen(n network.Network, addr ma.Multiaddr) {}
// ListenClose is no-op in this implementation.
func (nn *cmNotifee) ListenClose(n network.Network, addr ma.Multiaddr) {}

View File

@@ -0,0 +1,356 @@
package connmgr
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/benbjohnson/clock"
)
// DefaultResolution is the default resolution of the decay tracker.
var DefaultResolution = 1 * time.Minute
// bumpCmd represents a bump command.
type bumpCmd struct {
peer peer.ID
tag *decayingTag
delta int
}
// removeCmd represents a tag removal command.
type removeCmd struct {
peer peer.ID
tag *decayingTag
}
// decayer tracks and manages all decaying tags and their values.
type decayer struct {
cfg *DecayerCfg
mgr *BasicConnMgr
clock clock.Clock // for testing.
tagsMu sync.Mutex
knownTags map[string]*decayingTag
// lastTick stores the last time the decayer ticked. Guarded by atomic.
lastTick atomic.Pointer[time.Time]
// bumpTagCh queues bump commands to be processed by the loop.
bumpTagCh chan bumpCmd
removeTagCh chan removeCmd
closeTagCh chan *decayingTag
// closure thingies.
closeCh chan struct{}
doneCh chan struct{}
err error
}
var _ connmgr.Decayer = (*decayer)(nil)
// DecayerCfg is the configuration object for the Decayer.
type DecayerCfg struct {
Resolution time.Duration
Clock clock.Clock
}
// WithDefaults writes the default values on this DecayerConfig instance,
// and returns itself for chainability.
//
// cfg := (&DecayerCfg{}).WithDefaults()
// cfg.Resolution = 30 * time.Second
// t := NewDecayer(cfg, cm)
func (cfg *DecayerCfg) WithDefaults() *DecayerCfg {
cfg.Resolution = DefaultResolution
return cfg
}
// NewDecayer creates a new decaying tag registry.
func NewDecayer(cfg *DecayerCfg, mgr *BasicConnMgr) (*decayer, error) {
// use real time if the Clock in the config is nil.
if cfg.Clock == nil {
cfg.Clock = clock.New()
}
d := &decayer{
cfg: cfg,
mgr: mgr,
clock: cfg.Clock,
knownTags: make(map[string]*decayingTag),
bumpTagCh: make(chan bumpCmd, 128),
removeTagCh: make(chan removeCmd, 128),
closeTagCh: make(chan *decayingTag, 128),
closeCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
now := d.clock.Now()
d.lastTick.Store(&now)
// kick things off.
go d.process()
return d, nil
}
func (d *decayer) RegisterDecayingTag(name string, interval time.Duration, decayFn connmgr.DecayFn, bumpFn connmgr.BumpFn) (connmgr.DecayingTag, error) {
d.tagsMu.Lock()
defer d.tagsMu.Unlock()
if _, ok := d.knownTags[name]; ok {
return nil, fmt.Errorf("decaying tag with name %s already exists", name)
}
if interval < d.cfg.Resolution {
log.Warnf("decay interval for %s (%s) was lower than tracker's resolution (%s); overridden to resolution",
name, interval, d.cfg.Resolution)
interval = d.cfg.Resolution
}
if interval%d.cfg.Resolution != 0 {
log.Warnf("decay interval for tag %s (%s) is not a multiple of tracker's resolution (%s); "+
"some precision may be lost", name, interval, d.cfg.Resolution)
}
lastTick := d.lastTick.Load()
tag := &decayingTag{
trkr: d,
name: name,
interval: interval,
nextTick: lastTick.Add(interval),
decayFn: decayFn,
bumpFn: bumpFn,
}
d.knownTags[name] = tag
return tag, nil
}
// Close closes the Decayer. It is idempotent.
func (d *decayer) Close() error {
select {
case <-d.doneCh:
return d.err
default:
}
close(d.closeCh)
<-d.doneCh
return d.err
}
// process is the heart of the tracker. It performs the following duties:
//
// 1. Manages decay.
// 2. Applies score bumps.
// 3. Yields when closed.
func (d *decayer) process() {
defer close(d.doneCh)
ticker := d.clock.Ticker(d.cfg.Resolution)
defer ticker.Stop()
var (
bmp bumpCmd
visit = make(map[*decayingTag]struct{})
)
for {
select {
case <-ticker.C:
now := d.clock.Now()
d.lastTick.Store(&now)
d.tagsMu.Lock()
for _, tag := range d.knownTags {
if tag.nextTick.After(now) {
// skip the tag.
continue
}
// Mark the tag to be updated in this round.
visit[tag] = struct{}{}
}
d.tagsMu.Unlock()
// Visit each peer, and decay tags that need to be decayed.
for _, s := range d.mgr.segments.buckets {
s.Lock()
// Entered a segment that contains peers. Process each peer.
for _, p := range s.peers {
for tag, v := range p.decaying {
if _, ok := visit[tag]; !ok {
// skip this tag.
continue
}
// ~ this value needs to be visited. ~
var delta int
if after, rm := tag.decayFn(*v); rm {
// delete the value and move on to the next tag.
delta -= v.Value
delete(p.decaying, tag)
} else {
// accumulate the delta, and apply the changes.
delta += after - v.Value
v.Value, v.LastVisit = after, now
}
p.value += delta
}
}
s.Unlock()
}
// Reset each tag's next visit round, and clear the visited set.
for tag := range visit {
tag.nextTick = tag.nextTick.Add(tag.interval)
delete(visit, tag)
}
case bmp = <-d.bumpTagCh:
var (
now = d.clock.Now()
peer, tag = bmp.peer, bmp.tag
)
s := d.mgr.segments.get(peer)
s.Lock()
p := s.tagInfoFor(peer, d.clock.Now())
v, ok := p.decaying[tag]
if !ok {
v = &connmgr.DecayingValue{
Tag: tag,
Peer: peer,
LastVisit: now,
Added: now,
Value: 0,
}
p.decaying[tag] = v
}
prev := v.Value
v.Value, v.LastVisit = v.Tag.(*decayingTag).bumpFn(*v, bmp.delta), now
p.value += v.Value - prev
s.Unlock()
case rm := <-d.removeTagCh:
s := d.mgr.segments.get(rm.peer)
s.Lock()
p := s.tagInfoFor(rm.peer, d.clock.Now())
v, ok := p.decaying[rm.tag]
if !ok {
s.Unlock()
continue
}
p.value -= v.Value
delete(p.decaying, rm.tag)
s.Unlock()
case t := <-d.closeTagCh:
// Stop tracking the tag.
d.tagsMu.Lock()
delete(d.knownTags, t.name)
d.tagsMu.Unlock()
// Remove the tag from all peers that had it in the connmgr.
for _, s := range d.mgr.segments.buckets {
// visit all segments, and attempt to remove the tag from all the peers it stores.
s.Lock()
for _, p := range s.peers {
if dt, ok := p.decaying[t]; ok {
// decrease the value of the tagInfo, and delete the tag.
p.value -= dt.Value
delete(p.decaying, t)
}
}
s.Unlock()
}
case <-d.closeCh:
return
}
}
}
// decayingTag represents a decaying tag, with an associated decay interval, a
// decay function, and a bump function.
type decayingTag struct {
trkr *decayer
name string
interval time.Duration
nextTick time.Time
decayFn connmgr.DecayFn
bumpFn connmgr.BumpFn
// closed marks this tag as closed, so that if it's bumped after being
// closed, we can return an error.
closed atomic.Bool
}
var _ connmgr.DecayingTag = (*decayingTag)(nil)
func (t *decayingTag) Name() string {
return t.name
}
func (t *decayingTag) Interval() time.Duration {
return t.interval
}
// Bump bumps a tag for this peer.
func (t *decayingTag) Bump(p peer.ID, delta int) error {
if t.closed.Load() {
return fmt.Errorf("decaying tag %s had been closed; no further bumps are accepted", t.name)
}
bmp := bumpCmd{peer: p, tag: t, delta: delta}
select {
case t.trkr.bumpTagCh <- bmp:
return nil
default:
return fmt.Errorf(
"unable to bump decaying tag for peer %s, tag %s, delta %d; queue full (len=%d)",
p.Pretty(), t.name, delta, len(t.trkr.bumpTagCh))
}
}
func (t *decayingTag) Remove(p peer.ID) error {
if t.closed.Load() {
return fmt.Errorf("decaying tag %s had been closed; no further removals are accepted", t.name)
}
rm := removeCmd{peer: p, tag: t}
select {
case t.trkr.removeTagCh <- rm:
return nil
default:
return fmt.Errorf(
"unable to remove decaying tag for peer %s, tag %s; queue full (len=%d)",
p.Pretty(), t.name, len(t.trkr.removeTagCh))
}
}
func (t *decayingTag) Close() error {
if !t.closed.CompareAndSwap(false, true) {
log.Warnf("duplicate decaying tag closure: %s; skipping", t.name)
return nil
}
select {
case t.trkr.closeTagCh <- t:
return nil
default:
return fmt.Errorf("unable to close decaying tag %s; queue full (len=%d)", t.name, len(t.trkr.closeTagCh))
}
}

View File

@@ -0,0 +1,64 @@
package connmgr
import (
"errors"
"time"
"github.com/benbjohnson/clock"
)
// config is the configuration struct for the basic connection manager.
type config struct {
highWater int
lowWater int
gracePeriod time.Duration
silencePeriod time.Duration
decayer *DecayerCfg
emergencyTrim bool
clock clock.Clock
}
// Option represents an option for the basic connection manager.
type Option func(*config) error
// DecayerConfig applies a configuration for the decayer.
func DecayerConfig(opts *DecayerCfg) Option {
return func(cfg *config) error {
cfg.decayer = opts
return nil
}
}
// WithClock sets the internal clock impl
func WithClock(c clock.Clock) Option {
return func(cfg *config) error {
cfg.clock = c
return nil
}
}
// WithGracePeriod sets the grace period.
// The grace period is the time a newly opened connection is given before it becomes
// subject to pruning.
func WithGracePeriod(p time.Duration) Option {
return func(cfg *config) error {
if p < 0 {
return errors.New("grace period must be non-negative")
}
cfg.gracePeriod = p
return nil
}
}
// WithSilencePeriod sets the silence period.
// The connection manager will perform a cleanup once per silence period
// if the number of connections surpasses the high watermark.
func WithSilencePeriod(p time.Duration) Option {
return func(cfg *config) error {
if p <= 0 {
return errors.New("silence period must be non-zero")
}
cfg.silencePeriod = p
return nil
}
}

View File

@@ -0,0 +1,17 @@
//go:build cgo && !nowatchdog
package connmgr
import "github.com/raulk/go-watchdog"
func registerWatchdog(cb func()) (unregister func()) {
return watchdog.RegisterPostGCNotifee(cb)
}
// WithEmergencyTrim is an option to enable trimming connections on memory emergency.
func WithEmergencyTrim(enable bool) Option {
return func(cfg *config) error {
cfg.emergencyTrim = enable
return nil
}
}

View File

@@ -0,0 +1,15 @@
//go:build !cgo || nowatchdog
package connmgr
func registerWatchdog(func()) (unregister func()) {
return nil
}
// WithEmergencyTrim is an option to enable trimming connections on memory emergency.
func WithEmergencyTrim(enable bool) Option {
return func(cfg *config) error {
log.Warn("platform doesn't support go-watchdog")
return nil
}
}

257
vendor/github.com/libp2p/go-libp2p/p2p/net/nat/nat.go generated vendored Normal file
View File

@@ -0,0 +1,257 @@
package nat
import (
"context"
"errors"
"fmt"
"net/netip"
"sync"
"time"
logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-nat"
)
// ErrNoMapping signals no mapping exists for an address
var ErrNoMapping = errors.New("mapping not established")
var log = logging.Logger("nat")
// MappingDuration is a default port mapping duration.
// Port mappings are renewed every (MappingDuration / 3)
const MappingDuration = time.Minute
// CacheTime is the time a mapping will cache an external address for
const CacheTime = 15 * time.Second
type entry struct {
protocol string
port int
}
// so we can mock it in tests
var discoverGateway = nat.DiscoverGateway
// DiscoverNAT looks for a NAT device in the network and returns an object that can manage port mappings.
func DiscoverNAT(ctx context.Context) (*NAT, error) {
natInstance, err := discoverGateway(ctx)
if err != nil {
return nil, err
}
var extAddr netip.Addr
extIP, err := natInstance.GetExternalAddress()
if err == nil {
extAddr, _ = netip.AddrFromSlice(extIP)
}
// Log the device addr.
addr, err := natInstance.GetDeviceAddress()
if err != nil {
log.Debug("DiscoverGateway address error:", err)
} else {
log.Debug("DiscoverGateway address:", addr)
}
ctx, cancel := context.WithCancel(context.Background())
nat := &NAT{
nat: natInstance,
extAddr: extAddr,
mappings: make(map[entry]int),
ctx: ctx,
ctxCancel: cancel,
}
nat.refCount.Add(1)
go func() {
defer nat.refCount.Done()
nat.background()
}()
return nat, nil
}
// NAT is an object that manages address port mappings in
// NATs (Network Address Translators). It is a long-running
// service that will periodically renew port mappings,
// and keep an up-to-date list of all the external addresses.
type NAT struct {
natmu sync.Mutex
nat nat.NAT
// External IP of the NAT. Will be renewed periodically (every CacheTime).
extAddr netip.Addr
refCount sync.WaitGroup
ctx context.Context
ctxCancel context.CancelFunc
mappingmu sync.RWMutex // guards mappings
closed bool
mappings map[entry]int
}
// Close shuts down all port mappings. NAT can no longer be used.
func (nat *NAT) Close() error {
nat.mappingmu.Lock()
nat.closed = true
nat.mappingmu.Unlock()
nat.ctxCancel()
nat.refCount.Wait()
return nil
}
func (nat *NAT) GetMapping(protocol string, port int) (addr netip.AddrPort, found bool) {
nat.mappingmu.Lock()
defer nat.mappingmu.Unlock()
if !nat.extAddr.IsValid() {
return netip.AddrPort{}, false
}
extPort, found := nat.mappings[entry{protocol: protocol, port: port}]
if !found {
return netip.AddrPort{}, false
}
return netip.AddrPortFrom(nat.extAddr, uint16(extPort)), true
}
// AddMapping attempts to construct a mapping on protocol and internal port.
// It blocks until a mapping was established. Once added, it periodically renews the mapping.
//
// May not succeed, and mappings may change over time;
// NAT devices may not respect our port requests, and even lie.
func (nat *NAT) AddMapping(ctx context.Context, protocol string, port int) error {
switch protocol {
case "tcp", "udp":
default:
return fmt.Errorf("invalid protocol: %s", protocol)
}
nat.mappingmu.Lock()
defer nat.mappingmu.Unlock()
if nat.closed {
return errors.New("closed")
}
// do it once synchronously, so first mapping is done right away, and before exiting,
// allowing users -- in the optimistic case -- to use results right after.
extPort := nat.establishMapping(ctx, protocol, port)
nat.mappings[entry{protocol: protocol, port: port}] = extPort
return nil
}
// RemoveMapping removes a port mapping.
// It blocks until the NAT has removed the mapping.
func (nat *NAT) RemoveMapping(ctx context.Context, protocol string, port int) error {
nat.mappingmu.Lock()
defer nat.mappingmu.Unlock()
switch protocol {
case "tcp", "udp":
e := entry{protocol: protocol, port: port}
if _, ok := nat.mappings[e]; ok {
delete(nat.mappings, e)
return nat.nat.DeletePortMapping(ctx, protocol, port)
}
return errors.New("unknown mapping")
default:
return fmt.Errorf("invalid protocol: %s", protocol)
}
}
func (nat *NAT) background() {
const mappingUpdate = MappingDuration / 3
now := time.Now()
nextMappingUpdate := now.Add(mappingUpdate)
nextAddrUpdate := now.Add(CacheTime)
t := time.NewTimer(minTime(nextMappingUpdate, nextAddrUpdate).Sub(now)) // don't use a ticker here. We don't know how long establishing the mappings takes.
defer t.Stop()
var in []entry
var out []int // port numbers
for {
select {
case now := <-t.C:
if now.After(nextMappingUpdate) {
in = in[:0]
out = out[:0]
nat.mappingmu.Lock()
for e := range nat.mappings {
in = append(in, e)
}
nat.mappingmu.Unlock()
// Establishing the mapping involves network requests.
// Don't hold the mutex, just save the ports.
for _, e := range in {
out = append(out, nat.establishMapping(nat.ctx, e.protocol, e.port))
}
nat.mappingmu.Lock()
for i, p := range in {
if _, ok := nat.mappings[p]; !ok {
continue // entry might have been deleted
}
nat.mappings[p] = out[i]
}
nat.mappingmu.Unlock()
nextMappingUpdate = time.Now().Add(mappingUpdate)
}
if now.After(nextAddrUpdate) {
var extAddr netip.Addr
extIP, err := nat.nat.GetExternalAddress()
if err == nil {
extAddr, _ = netip.AddrFromSlice(extIP)
}
nat.extAddr = extAddr
nextAddrUpdate = time.Now().Add(CacheTime)
}
t.Reset(time.Until(minTime(nextAddrUpdate, nextMappingUpdate)))
case <-nat.ctx.Done():
nat.mappingmu.Lock()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for e := range nat.mappings {
delete(nat.mappings, e)
nat.nat.DeletePortMapping(ctx, e.protocol, e.port)
}
nat.mappingmu.Unlock()
return
}
}
}
func (nat *NAT) establishMapping(ctx context.Context, protocol string, internalPort int) (externalPort int) {
log.Debugf("Attempting port map: %s/%d", protocol, internalPort)
const comment = "libp2p"
nat.natmu.Lock()
var err error
externalPort, err = nat.nat.AddPortMapping(ctx, protocol, internalPort, comment, MappingDuration)
if err != nil {
// Some hardware does not support mappings with timeout, so try that
externalPort, err = nat.nat.AddPortMapping(ctx, protocol, internalPort, comment, 0)
}
nat.natmu.Unlock()
if err != nil || externalPort == 0 {
// TODO: log.Event
if err != nil {
log.Warnf("failed to establish port mapping: %s", err)
} else {
log.Warnf("failed to establish port mapping: newport = 0")
}
// we do not close if the mapping failed,
// because it may work again next time.
return 0
}
log.Debugf("NAT Mapping: %d --> %d (%s)", externalPort, internalPort, protocol)
return externalPort
}
func minTime(a, b time.Time) time.Time {
if a.Before(b) {
return a
}
return b
}

View File

@@ -0,0 +1,18 @@
package pnet
import (
"errors"
"net"
ipnet "github.com/libp2p/go-libp2p/core/pnet"
)
// NewProtectedConn creates a new protected connection
func NewProtectedConn(psk ipnet.PSK, conn net.Conn) (net.Conn, error) {
if len(psk) != 32 {
return nil, errors.New("expected 32 byte PSK")
}
var p [32]byte
copy(p[:], psk)
return newPSKConn(&p, conn)
}

View File

@@ -0,0 +1,83 @@
package pnet
import (
"crypto/cipher"
"crypto/rand"
"io"
"net"
"github.com/libp2p/go-libp2p/core/pnet"
"github.com/davidlazar/go-crypto/salsa20"
pool "github.com/libp2p/go-buffer-pool"
)
// we are using buffer pool as user needs their slice back
// so we can't do XOR cripter in place
var (
errShortNonce = pnet.NewError("could not read full nonce")
errInsecureNil = pnet.NewError("insecure is nil")
errPSKNil = pnet.NewError("pre-shread key is nil")
)
type pskConn struct {
net.Conn
psk *[32]byte
writeS20 cipher.Stream
readS20 cipher.Stream
}
func (c *pskConn) Read(out []byte) (int, error) {
if c.readS20 == nil {
nonce := make([]byte, 24)
_, err := io.ReadFull(c.Conn, nonce)
if err != nil {
return 0, errShortNonce
}
c.readS20 = salsa20.New(c.psk, nonce)
}
n, err := c.Conn.Read(out) // read to in
if n > 0 {
c.readS20.XORKeyStream(out[:n], out[:n]) // decrypt to out buffer
}
return n, err
}
func (c *pskConn) Write(in []byte) (int, error) {
if c.writeS20 == nil {
nonce := make([]byte, 24)
_, err := rand.Read(nonce)
if err != nil {
return 0, err
}
_, err = c.Conn.Write(nonce)
if err != nil {
return 0, err
}
c.writeS20 = salsa20.New(c.psk, nonce)
}
out := pool.Get(len(in))
defer pool.Put(out)
c.writeS20.XORKeyStream(out, in) // encrypt
return c.Conn.Write(out) // send
}
var _ net.Conn = (*pskConn)(nil)
func newPSKConn(psk *[32]byte, insecure net.Conn) (net.Conn, error) {
if insecure == nil {
return nil, errInsecureNil
}
if psk == nil {
return nil, errPSKNil
}
return &pskConn{
Conn: insecure,
psk: psk,
}, nil
}

View File

@@ -0,0 +1,62 @@
package reuseport
import (
"context"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
// Dial dials the given multiaddr, reusing ports we're currently listening on if
// possible.
//
// Dial attempts to be smart about choosing the source port. For example, If
// we're dialing a loopback address and we're listening on one or more loopback
// ports, Dial will randomly choose one of the loopback ports and addresses and
// reuse it.
func (t *Transport) Dial(raddr ma.Multiaddr) (manet.Conn, error) {
return t.DialContext(context.Background(), raddr)
}
// DialContext is like Dial but takes a context.
func (t *Transport) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
network, addr, err := manet.DialArgs(raddr)
if err != nil {
return nil, err
}
var d *dialer
switch network {
case "tcp4":
d = t.v4.getDialer(network)
case "tcp6":
d = t.v6.getDialer(network)
default:
return nil, ErrWrongProto
}
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
maconn, err := manet.WrapNetConn(conn)
if err != nil {
conn.Close()
return nil, err
}
return maconn, nil
}
func (n *network) getDialer(network string) *dialer {
n.mu.RLock()
d := n.dialer
n.mu.RUnlock()
if d == nil {
n.mu.Lock()
defer n.mu.Unlock()
if n.dialer == nil {
n.dialer = newDialer(n.listeners)
}
d = n.dialer
}
return d
}

View File

@@ -0,0 +1,114 @@
package reuseport
import (
"context"
"fmt"
"math/rand"
"net"
"github.com/libp2p/go-netroute"
)
type dialer struct {
// All address that are _not_ loopback or unspecified (0.0.0.0 or ::).
specific []*net.TCPAddr
// All loopback addresses (127.*.*.*, ::1).
loopback []*net.TCPAddr
// Unspecified addresses (0.0.0.0, ::)
unspecified []*net.TCPAddr
}
func (d *dialer) Dial(network, addr string) (net.Conn, error) {
return d.DialContext(context.Background(), network, addr)
}
func randAddr(addrs []*net.TCPAddr) *net.TCPAddr {
if len(addrs) > 0 {
return addrs[rand.Intn(len(addrs))]
}
return nil
}
// DialContext dials a target addr.
//
// In-order:
//
// 1. If we're _explicitly_ listening on the prefered source address for the destination address
// (per the system's routes), we'll use that listener's port as the source port.
// 2. If we're listening on one or more _unspecified_ addresses (zero address), we'll pick a source
// port from one of these listener's.
// 3. Otherwise, we'll let the system pick the source port.
func (d *dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
// We only check this case if the user is listening on a specific address (loopback or
// otherwise). Generally, users will listen on the "unspecified" address (0.0.0.0 or ::) and
// we can skip this section.
//
// This lets us avoid resolving the address twice, in most cases.
if len(d.specific) > 0 || len(d.loopback) > 0 {
tcpAddr, err := net.ResolveTCPAddr(network, addr)
if err != nil {
return nil, err
}
ip := tcpAddr.IP
if !ip.IsLoopback() && !ip.IsGlobalUnicast() {
return nil, fmt.Errorf("undialable IP: %s", ip)
}
// If we're listening on some specific address and that specific address happens to
// be the preferred source address for the target destination address, we try to
// dial with that address/port.
//
// We skip this check if we _aren't_ listening on any specific addresses, because
// checking routing tables can be expensive and users rarely listen on specific IP
// addresses.
if len(d.specific) > 0 {
if router, err := netroute.New(); err == nil {
if _, _, preferredSrc, err := router.Route(ip); err == nil {
for _, optAddr := range d.specific {
if optAddr.IP.Equal(preferredSrc) {
return reuseDial(ctx, optAddr, network, addr)
}
}
}
}
}
// Otherwise, if we are listening on a loopback address and the destination is also
// a loopback address, use the port from our loopback listener.
if len(d.loopback) > 0 && ip.IsLoopback() {
return reuseDial(ctx, randAddr(d.loopback), network, addr)
}
}
// If we're listening on any uspecified addresses, use a randomly chosen port from one of
// these listeners.
if len(d.unspecified) > 0 {
return reuseDial(ctx, randAddr(d.unspecified), network, addr)
}
// Finally, just pick a random port.
var dialer net.Dialer
return dialer.DialContext(ctx, network, addr)
}
func newDialer(listeners map[*listener]struct{}) *dialer {
specific := make([]*net.TCPAddr, 0)
loopback := make([]*net.TCPAddr, 0)
unspecified := make([]*net.TCPAddr, 0)
for l := range listeners {
addr := l.Addr().(*net.TCPAddr)
if addr.IP.IsLoopback() {
loopback = append(loopback, addr)
} else if addr.IP.IsUnspecified() {
unspecified = append(unspecified, addr)
} else {
specific = append(specific, addr)
}
}
return &dialer{
specific: specific,
loopback: loopback,
unspecified: unspecified,
}
}

View File

@@ -0,0 +1,80 @@
package reuseport
import (
"net"
"github.com/libp2p/go-reuseport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
type listener struct {
manet.Listener
network *network
}
func (l *listener) Close() error {
l.network.mu.Lock()
delete(l.network.listeners, l)
l.network.dialer = nil
l.network.mu.Unlock()
return l.Listener.Close()
}
// Listen listens on the given multiaddr.
//
// If reuseport is supported, it will be enabled for this listener and future
// dials from this transport may reuse the port.
//
// Note: You can listen on the same multiaddr as many times as you want
// (although only *one* listener will end up handling the inbound connection).
func (t *Transport) Listen(laddr ma.Multiaddr) (manet.Listener, error) {
nw, naddr, err := manet.DialArgs(laddr)
if err != nil {
return nil, err
}
var n *network
switch nw {
case "tcp4":
n = &t.v4
case "tcp6":
n = &t.v6
default:
return nil, ErrWrongProto
}
if !reuseport.Available() {
return manet.Listen(laddr)
}
nl, err := reuseport.Listen(nw, naddr)
if err != nil {
return manet.Listen(laddr)
}
if _, ok := nl.Addr().(*net.TCPAddr); !ok {
nl.Close()
return nil, ErrWrongProto
}
malist, err := manet.WrapNetListener(nl)
if err != nil {
nl.Close()
return nil, err
}
list := &listener{
Listener: malist,
network: n,
}
n.mu.Lock()
defer n.mu.Unlock()
if n.listeners == nil {
n.listeners = make(map[*listener]struct{})
}
n.listeners[list] = struct{}{}
n.dialer = nil
return list, nil
}

View File

@@ -0,0 +1,35 @@
package reuseport
import (
"context"
"net"
"github.com/libp2p/go-reuseport"
)
var fallbackDialer net.Dialer
// Dials using reuseport and then redials normally if that fails.
func reuseDial(ctx context.Context, laddr *net.TCPAddr, network, raddr string) (con net.Conn, err error) {
if laddr == nil {
return fallbackDialer.DialContext(ctx, network, raddr)
}
d := net.Dialer{
LocalAddr: laddr,
Control: reuseport.Control,
}
con, err = d.DialContext(ctx, network, raddr)
if err == nil {
return con, nil
}
if reuseErrShouldRetry(err) && ctx.Err() == nil {
// We could have an existing socket open or we could have one
// stuck in TIME-WAIT.
log.Debugf("failed to reuse port, will try again with a random port: %s", err)
con, err = fallbackDialer.DialContext(ctx, network, raddr)
}
return con, err
}

View File

@@ -0,0 +1,44 @@
package reuseport
import (
"net"
"os"
)
const (
EADDRINUSE = "address in use"
ECONNREFUSED = "connection refused"
)
// reuseErrShouldRetry diagnoses whether to retry after a reuse error.
// if we failed to bind, we should retry. if bind worked and this is a
// real dial error (remote end didnt answer) then we should not retry.
func reuseErrShouldRetry(err error) bool {
if err == nil {
return false // hey, it worked! no need to retry.
}
// if it's a network timeout error, it's a legitimate failure.
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
return false
}
e, ok := err.(*net.OpError)
if !ok {
return true
}
e1, ok := e.Err.(*os.PathError)
if !ok {
return true
}
switch e1.Err.Error() {
case EADDRINUSE:
return true
case ECONNREFUSED:
return false
default:
return true // optimistically default to retry.
}
}

View File

@@ -0,0 +1,36 @@
//go:build !plan9
package reuseport
import (
"net"
"syscall"
)
// reuseErrShouldRetry diagnoses whether to retry after a reuse error.
// if we failed to bind, we should retry. if bind worked and this is a
// real dial error (remote end didnt answer) then we should not retry.
func reuseErrShouldRetry(err error) bool {
if err == nil {
return false // hey, it worked! no need to retry.
}
// if it's a network timeout error, it's a legitimate failure.
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
return false
}
errno, ok := err.(syscall.Errno)
if !ok { // not an errno? who knows what this is. retry.
return true
}
switch errno {
case syscall.EADDRINUSE, syscall.EADDRNOTAVAIL:
return true // failure to bind. retry.
case syscall.ECONNREFUSED:
return false // real dial error
default:
return true // optimistically default to retry.
}
}

View File

@@ -0,0 +1,35 @@
// Package reuseport provides a basic transport for automatically (and intelligently) reusing TCP ports.
//
// To use, construct a new Transport and configure listeners tr.Listen(...).
// When dialing (tr.Dial(...)), the transport will attempt to reuse the ports it's currently listening on,
// choosing the best one depending on the destination address.
//
// It is recommended to set SO_LINGER to 0 for all connections, otherwise
// reusing the port may fail when re-dialing a recently closed connection.
// See https://hea-www.harvard.edu/~fine/Tech/addrinuse.html for details.
package reuseport
import (
"errors"
"sync"
logging "github.com/ipfs/go-log/v2"
)
var log = logging.Logger("reuseport-transport")
// ErrWrongProto is returned when dialing a protocol other than tcp.
var ErrWrongProto = errors.New("can only dial TCP over IPv4 or IPv6")
// Transport is a TCP reuse transport that reuses listener ports.
// The zero value is safe to use.
type Transport struct {
v4 network
v6 network
}
type network struct {
mu sync.RWMutex
listeners map[*listener]struct{}
dialer *dialer
}

View File

@@ -0,0 +1,39 @@
package swarm
import (
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
// http://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml
var lowTimeoutFilters = ma.NewFilters()
func init() {
for _, p := range []string{
"/ip4/10.0.0.0/ipcidr/8",
"/ip4/100.64.0.0/ipcidr/10",
"/ip4/169.254.0.0/ipcidr/16",
"/ip4/172.16.0.0/ipcidr/12",
"/ip4/192.0.0.0/ipcidr/24",
"/ip4/192.0.0.0/ipcidr/29",
"/ip4/192.0.0.8/ipcidr/32",
"/ip4/192.0.0.170/ipcidr/32",
"/ip4/192.0.0.171/ipcidr/32",
"/ip4/192.0.2.0/ipcidr/24",
"/ip4/192.168.0.0/ipcidr/16",
"/ip4/198.18.0.0/ipcidr/15",
"/ip4/198.51.100.0/ipcidr/24",
"/ip4/203.0.113.0/ipcidr/24",
"/ip4/240.0.0.0/ipcidr/4",
} {
f, err := ma.NewMultiaddr(p)
if err != nil {
panic("error in lowTimeoutFilters init: " + err.Error())
}
ipnet, err := manet.MultiaddrToIPNet(f)
if err != nil {
panic("error in lowTimeoutFilters init: " + err.Error())
}
lowTimeoutFilters.AddFilter(*ipnet, ma.ActionDeny)
}
}

View File

@@ -0,0 +1,276 @@
package swarm
import (
"fmt"
"sync"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
type blackHoleState int
const (
blackHoleStateProbing blackHoleState = iota
blackHoleStateAllowed
blackHoleStateBlocked
)
func (st blackHoleState) String() string {
switch st {
case blackHoleStateProbing:
return "Probing"
case blackHoleStateAllowed:
return "Allowed"
case blackHoleStateBlocked:
return "Blocked"
default:
return fmt.Sprintf("Unknown %d", st)
}
}
type blackHoleResult int
const (
blackHoleResultAllowed blackHoleResult = iota
blackHoleResultProbing
blackHoleResultBlocked
)
// blackHoleFilter provides black hole filtering for dials. This filter should be used in
// concert with a UDP of IPv6 address filter to detect UDP or IPv6 black hole. In a black
// holed environments dial requests are blocked and only periodic probes to check the
// state of the black hole are allowed.
//
// Requests are blocked if the number of successes in the last n dials is less than
// minSuccesses. If a request succeeds in Blocked state, the filter state is reset and n
// subsequent requests are allowed before reevaluating black hole state. Dials cancelled
// when some other concurrent dial succeeded are counted as failures. A sufficiently large
// n prevents false negatives in such cases.
type blackHoleFilter struct {
// n serves the dual purpose of being the minimum number of requests after which we
// probe the state of the black hole in blocked state and the minimum number of
// completed dials required before evaluating black hole state.
n int
// minSuccesses is the minimum number of Success required in the last n dials
// to consider we are not blocked.
minSuccesses int
// name for the detector.
name string
// requests counts number of dial requests to peers. We handle request at a peer
// level and record results at individual address dial level.
requests int
// dialResults of the last `n` dials. A successful dial is true.
dialResults []bool
// successes is the count of successful dials in outcomes
successes int
// state is the current state of the detector
state blackHoleState
mu sync.Mutex
metricsTracer MetricsTracer
}
// RecordResult records the outcome of a dial. A successful dial will change the state
// of the filter to Allowed. A failed dial only blocks subsequent requests if the success
// fraction over the last n outcomes is less than the minSuccessFraction of the filter.
func (b *blackHoleFilter) RecordResult(success bool) {
b.mu.Lock()
defer b.mu.Unlock()
if b.state == blackHoleStateBlocked && success {
// If the call succeeds in a blocked state we reset to allowed.
// This is better than slowly accumulating values till we cross the minSuccessFraction
// threshold since a blackhole is a binary property.
b.reset()
return
}
if success {
b.successes++
}
b.dialResults = append(b.dialResults, success)
if len(b.dialResults) > b.n {
if b.dialResults[0] {
b.successes--
}
b.dialResults = b.dialResults[1:]
}
b.updateState()
b.trackMetrics()
}
// HandleRequest returns the result of applying the black hole filter for the request.
func (b *blackHoleFilter) HandleRequest() blackHoleResult {
b.mu.Lock()
defer b.mu.Unlock()
b.requests++
b.trackMetrics()
if b.state == blackHoleStateAllowed {
return blackHoleResultAllowed
} else if b.state == blackHoleStateProbing || b.requests%b.n == 0 {
return blackHoleResultProbing
} else {
return blackHoleResultBlocked
}
}
func (b *blackHoleFilter) reset() {
b.successes = 0
b.dialResults = b.dialResults[:0]
b.requests = 0
b.updateState()
}
func (b *blackHoleFilter) updateState() {
st := b.state
if len(b.dialResults) < b.n {
b.state = blackHoleStateProbing
} else if b.successes >= b.minSuccesses {
b.state = blackHoleStateAllowed
} else {
b.state = blackHoleStateBlocked
}
if st != b.state {
log.Debugf("%s blackHoleDetector state changed from %s to %s", b.name, st, b.state)
}
}
func (b *blackHoleFilter) trackMetrics() {
if b.metricsTracer == nil {
return
}
nextRequestAllowedAfter := 0
if b.state == blackHoleStateBlocked {
nextRequestAllowedAfter = b.n - (b.requests % b.n)
}
successFraction := 0.0
if len(b.dialResults) > 0 {
successFraction = float64(b.successes) / float64(len(b.dialResults))
}
b.metricsTracer.UpdatedBlackHoleFilterState(
b.name,
b.state,
nextRequestAllowedAfter,
successFraction,
)
}
// blackHoleDetector provides UDP and IPv6 black hole detection using a `blackHoleFilter`
// for each. For details of the black hole detection logic see `blackHoleFilter`.
//
// black hole filtering is done at a peer dial level to ensure that periodic probes to
// detect change of the black hole state are actually dialed and are not skipped
// because of dial prioritisation logic.
type blackHoleDetector struct {
udp, ipv6 *blackHoleFilter
}
// FilterAddrs filters the peer's addresses removing black holed addresses
func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
hasUDP, hasIPv6 := false, false
for _, a := range addrs {
if !manet.IsPublicAddr(a) {
continue
}
if isProtocolAddr(a, ma.P_UDP) {
hasUDP = true
}
if isProtocolAddr(a, ma.P_IP6) {
hasIPv6 = true
}
}
udpRes := blackHoleResultAllowed
if d.udp != nil && hasUDP {
udpRes = d.udp.HandleRequest()
}
ipv6Res := blackHoleResultAllowed
if d.ipv6 != nil && hasIPv6 {
ipv6Res = d.ipv6.HandleRequest()
}
return ma.FilterAddrs(
addrs,
func(a ma.Multiaddr) bool {
if !manet.IsPublicAddr(a) {
return true
}
// allow all UDP addresses while probing irrespective of IPv6 black hole state
if udpRes == blackHoleResultProbing && isProtocolAddr(a, ma.P_UDP) {
return true
}
// allow all IPv6 addresses while probing irrespective of UDP black hole state
if ipv6Res == blackHoleResultProbing && isProtocolAddr(a, ma.P_IP6) {
return true
}
if udpRes == blackHoleResultBlocked && isProtocolAddr(a, ma.P_UDP) {
return false
}
if ipv6Res == blackHoleResultBlocked && isProtocolAddr(a, ma.P_IP6) {
return false
}
return true
},
)
}
// RecordResult updates the state of the relevant `blackHoleFilter`s for addr
func (d *blackHoleDetector) RecordResult(addr ma.Multiaddr, success bool) {
if !manet.IsPublicAddr(addr) {
return
}
if d.udp != nil && isProtocolAddr(addr, ma.P_UDP) {
d.udp.RecordResult(success)
}
if d.ipv6 != nil && isProtocolAddr(addr, ma.P_IP6) {
d.ipv6.RecordResult(success)
}
}
// blackHoleConfig is the config used for black hole detection
type blackHoleConfig struct {
// Enabled enables black hole detection
Enabled bool
// N is the size of the sliding window used to evaluate black hole state
N int
// MinSuccesses is the minimum number of successes out of N required to not
// block requests
MinSuccesses int
}
func newBlackHoleDetector(udpConfig, ipv6Config blackHoleConfig, mt MetricsTracer) *blackHoleDetector {
d := &blackHoleDetector{}
if udpConfig.Enabled {
d.udp = &blackHoleFilter{
n: udpConfig.N,
minSuccesses: udpConfig.MinSuccesses,
name: "UDP",
metricsTracer: mt,
}
}
if ipv6Config.Enabled {
d.ipv6 = &blackHoleFilter{
n: ipv6Config.N,
minSuccesses: ipv6Config.MinSuccesses,
name: "IPv6",
metricsTracer: mt,
}
}
return d
}

View File

@@ -0,0 +1,49 @@
package swarm
import "time"
// InstantTimer is a timer that triggers at some instant rather than some duration
type InstantTimer interface {
Reset(d time.Time) bool
Stop() bool
Ch() <-chan time.Time
}
// Clock is a clock that can create timers that trigger at some
// instant rather than some duration
type Clock interface {
Now() time.Time
Since(t time.Time) time.Duration
InstantTimer(when time.Time) InstantTimer
}
type RealTimer struct{ t *time.Timer }
var _ InstantTimer = (*RealTimer)(nil)
func (t RealTimer) Ch() <-chan time.Time {
return t.t.C
}
func (t RealTimer) Reset(d time.Time) bool {
return t.t.Reset(time.Until(d))
}
func (t RealTimer) Stop() bool {
return t.t.Stop()
}
type RealClock struct{}
var _ Clock = RealClock{}
func (RealClock) Now() time.Time {
return time.Now()
}
func (RealClock) Since(t time.Time) time.Duration {
return time.Since(t)
}
func (RealClock) InstantTimer(when time.Time) InstantTimer {
t := time.NewTimer(time.Until(when))
return &RealTimer{t}
}

View File

@@ -0,0 +1,71 @@
package swarm
import (
"fmt"
"os"
"strings"
"github.com/libp2p/go-libp2p/core/peer"
ma "github.com/multiformats/go-multiaddr"
)
// maxDialDialErrors is the maximum number of dial errors we record
const maxDialDialErrors = 16
// DialError is the error type returned when dialing.
type DialError struct {
Peer peer.ID
DialErrors []TransportError
Cause error
Skipped int
}
func (e *DialError) Timeout() bool {
return os.IsTimeout(e.Cause)
}
func (e *DialError) recordErr(addr ma.Multiaddr, err error) {
if len(e.DialErrors) >= maxDialDialErrors {
e.Skipped++
return
}
e.DialErrors = append(e.DialErrors, TransportError{
Address: addr,
Cause: err,
})
}
func (e *DialError) Error() string {
var builder strings.Builder
fmt.Fprintf(&builder, "failed to dial %s:", e.Peer)
if e.Cause != nil {
fmt.Fprintf(&builder, " %s", e.Cause)
}
for _, te := range e.DialErrors {
fmt.Fprintf(&builder, "\n * [%s] %s", te.Address, te.Cause)
}
if e.Skipped > 0 {
fmt.Fprintf(&builder, "\n ... skipping %d errors ...", e.Skipped)
}
return builder.String()
}
// Unwrap implements https://godoc.org/golang.org/x/xerrors#Wrapper.
func (e *DialError) Unwrap() error {
return e.Cause
}
var _ error = (*DialError)(nil)
// TransportError is the error returned when dialing a specific address.
type TransportError struct {
Address ma.Multiaddr
Cause error
}
func (e *TransportError) Error() string {
return fmt.Sprintf("failed to dial %s: %s", e.Address, e.Cause)
}
var _ error = (*TransportError)(nil)

View File

@@ -0,0 +1,199 @@
package swarm
import (
"sort"
"strconv"
"time"
"github.com/libp2p/go-libp2p/core/network"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
// The 250ms value is from happy eyeballs RFC 8305. This is a rough estimate of 1 RTT
const (
// duration by which TCP dials are delayed relative to the last QUIC dial
PublicTCPDelay = 250 * time.Millisecond
PrivateTCPDelay = 30 * time.Millisecond
// duration by which QUIC dials are delayed relative to previous QUIC dial
PublicQUICDelay = 250 * time.Millisecond
PrivateQUICDelay = 30 * time.Millisecond
// RelayDelay is the duration by which relay dials are delayed relative to direct addresses
RelayDelay = 500 * time.Millisecond
)
// NoDelayDialRanker ranks addresses with no delay. This is useful for simultaneous connect requests.
func NoDelayDialRanker(addrs []ma.Multiaddr) []network.AddrDelay {
return getAddrDelay(addrs, 0, 0, 0)
}
// DefaultDialRanker determines the ranking of outgoing connection attempts.
//
// Addresses are grouped into three distinct groups:
//
// - private addresses (localhost and local networks (RFC 1918))
// - public addresses
// - relay addresses
//
// Within each group, the addresses are ranked according to the ranking logic described below.
// We then dial addresses according to this ranking, with short timeouts applied between dial attempts.
// This ranking logic dramatically reduces the number of simultaneous dial attempts, while introducing
// no additional latency in the vast majority of cases.
//
// Private and public address groups are dialed in parallel.
// Dialing relay addresses is delayed by 500 ms, if we have any non-relay alternatives.
//
// Within each group (private, public, relay addresses) we apply the following ranking logic:
//
// 1. If both IPv6 QUIC and IPv4 QUIC addresses are present, we do a Happy Eyeballs RFC 8305 style ranking.
// First dial the IPv6 QUIC address with the lowest port. After this we dial the IPv4 QUIC address with
// the lowest port delayed by 250ms (PublicQUICDelay) for public addresses, and 30ms (PrivateQUICDelay)
// for local addresses. After this we dial all the rest of the addresses delayed by 250ms (PublicQUICDelay)
// for public addresses, and 30ms (PrivateQUICDelay) for local addresses.
// 2. If only one of QUIC IPv6 or QUIC IPv4 addresses are present, dial the QUIC address with the lowest port
// first. After this we dial the rest of the QUIC addresses delayed by 250ms (PublicQUICDelay) for public
// addresses, and 30ms (PrivateQUICDelay) for local addresses.
// 3. If a QUIC or WebTransport address is present, TCP addresses dials are delayed relative to the last QUIC dial:
// We prefer to end up with a QUIC connection. For public addresses, the delay introduced is 250ms (PublicTCPDelay),
// and for private addresses 30ms (PrivateTCPDelay).
//
// We dial lowest ports first for QUIC addresses as they are more likely to be the listen port.
func DefaultDialRanker(addrs []ma.Multiaddr) []network.AddrDelay {
relay, addrs := filterAddrs(addrs, isRelayAddr)
pvt, addrs := filterAddrs(addrs, manet.IsPrivateAddr)
public, addrs := filterAddrs(addrs, func(a ma.Multiaddr) bool { return isProtocolAddr(a, ma.P_IP4) || isProtocolAddr(a, ma.P_IP6) })
var relayOffset time.Duration
if len(public) > 0 {
// if there is a public direct address available delay relay dials
relayOffset = RelayDelay
}
res := make([]network.AddrDelay, 0, len(addrs))
for i := 0; i < len(addrs); i++ {
res = append(res, network.AddrDelay{Addr: addrs[i], Delay: 0})
}
res = append(res, getAddrDelay(pvt, PrivateTCPDelay, PrivateQUICDelay, 0)...)
res = append(res, getAddrDelay(public, PublicTCPDelay, PublicQUICDelay, 0)...)
res = append(res, getAddrDelay(relay, PublicTCPDelay, PublicQUICDelay, relayOffset)...)
return res
}
// getAddrDelay ranks a group of addresses according to the ranking logic explained in
// documentation for defaultDialRanker.
// offset is used to delay all addresses by a fixed duration. This is useful for delaying all relay
// addresses relative to direct addresses.
func getAddrDelay(addrs []ma.Multiaddr, tcpDelay time.Duration, quicDelay time.Duration,
offset time.Duration) []network.AddrDelay {
sort.Slice(addrs, func(i, j int) bool { return score(addrs[i]) < score(addrs[j]) })
// If the first address is (QUIC, IPv6), make the second address (QUIC, IPv4).
happyEyeballs := false
if len(addrs) > 0 {
if isQUICAddr(addrs[0]) && isProtocolAddr(addrs[0], ma.P_IP6) {
for i := 1; i < len(addrs); i++ {
if isQUICAddr(addrs[i]) && isProtocolAddr(addrs[i], ma.P_IP4) {
// make IPv4 address the second element
if i > 1 {
a := addrs[i]
copy(addrs[2:], addrs[1:i])
addrs[1] = a
}
happyEyeballs = true
break
}
}
}
}
res := make([]network.AddrDelay, 0, len(addrs))
var totalTCPDelay time.Duration
for i, addr := range addrs {
var delay time.Duration
switch {
case isQUICAddr(addr):
// For QUIC addresses we dial an IPv6 address, then after quicDelay an IPv4
// address, then after quicDelay we dial rest of the addresses.
if i == 1 {
delay = quicDelay
}
if i > 1 && happyEyeballs {
delay = 2 * quicDelay
} else if i > 1 {
delay = quicDelay
}
totalTCPDelay = delay + tcpDelay
case isProtocolAddr(addr, ma.P_TCP):
delay = totalTCPDelay
}
res = append(res, network.AddrDelay{Addr: addr, Delay: offset + delay})
}
return res
}
// score scores a multiaddress for dialing delay. Lower is better.
// The lower 16 bits of the result are the port. Low ports are ranked higher because they're
// more likely to be listen addresses.
// The addresses are ranked as:
// QUICv1 IPv6 > QUICdraft29 IPv6 > QUICv1 IPv4 > QUICdraft29 IPv4 >
// WebTransport IPv6 > WebTransport IPv4 > TCP IPv6 > TCP IPv4
func score(a ma.Multiaddr) int {
ip4Weight := 0
if isProtocolAddr(a, ma.P_IP4) {
ip4Weight = 1 << 18
}
if _, err := a.ValueForProtocol(ma.P_WEBTRANSPORT); err == nil {
p, _ := a.ValueForProtocol(ma.P_UDP)
pi, _ := strconv.Atoi(p)
return ip4Weight + (1 << 19) + pi
}
if _, err := a.ValueForProtocol(ma.P_QUIC); err == nil {
p, _ := a.ValueForProtocol(ma.P_UDP)
pi, _ := strconv.Atoi(p)
return ip4Weight + pi + (1 << 17)
}
if _, err := a.ValueForProtocol(ma.P_QUIC_V1); err == nil {
p, _ := a.ValueForProtocol(ma.P_UDP)
pi, _ := strconv.Atoi(p)
return ip4Weight + pi
}
if p, err := a.ValueForProtocol(ma.P_TCP); err == nil {
pi, _ := strconv.Atoi(p)
return ip4Weight + pi + (1 << 20)
}
return (1 << 30)
}
func isProtocolAddr(a ma.Multiaddr, p int) bool {
found := false
ma.ForEach(a, func(c ma.Component) bool {
if c.Protocol().Code == p {
found = true
return false
}
return true
})
return found
}
func isQUICAddr(a ma.Multiaddr) bool {
return isProtocolAddr(a, ma.P_QUIC) || isProtocolAddr(a, ma.P_QUIC_V1)
}
// filterAddrs filters an address slice in place
func filterAddrs(addrs []ma.Multiaddr, f func(a ma.Multiaddr) bool) (filtered, rest []ma.Multiaddr) {
j := 0
for i := 0; i < len(addrs); i++ {
if f(addrs[i]) {
addrs[i], addrs[j] = addrs[j], addrs[i]
j++
}
}
return addrs[:j], addrs[j:]
}

View File

@@ -0,0 +1,109 @@
package swarm
import (
"context"
"sync"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
)
// dialWorkerFunc is used by dialSync to spawn a new dial worker
type dialWorkerFunc func(peer.ID, <-chan dialRequest)
// newDialSync constructs a new dialSync
func newDialSync(worker dialWorkerFunc) *dialSync {
return &dialSync{
dials: make(map[peer.ID]*activeDial),
dialWorker: worker,
}
}
// dialSync is a dial synchronization helper that ensures that at most one dial
// to any given peer is active at any given time.
type dialSync struct {
mutex sync.Mutex
dials map[peer.ID]*activeDial
dialWorker dialWorkerFunc
}
type activeDial struct {
refCnt int
ctx context.Context
cancel func()
reqch chan dialRequest
}
func (ad *activeDial) close() {
ad.cancel()
close(ad.reqch)
}
func (ad *activeDial) dial(ctx context.Context) (*Conn, error) {
dialCtx := ad.ctx
if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect {
dialCtx = network.WithForceDirectDial(dialCtx, reason)
}
if simConnect, isClient, reason := network.GetSimultaneousConnect(ctx); simConnect {
dialCtx = network.WithSimultaneousConnect(dialCtx, isClient, reason)
}
resch := make(chan dialResponse, 1)
select {
case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}:
case <-ctx.Done():
return nil, ctx.Err()
}
select {
case res := <-resch:
return res.conn, res.err
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (ds *dialSync) getActiveDial(p peer.ID) (*activeDial, error) {
ds.mutex.Lock()
defer ds.mutex.Unlock()
actd, ok := ds.dials[p]
if !ok {
// This code intentionally uses the background context. Otherwise, if the first call
// to Dial is canceled, subsequent dial calls will also be canceled.
ctx, cancel := context.WithCancel(context.Background())
actd = &activeDial{
ctx: ctx,
cancel: cancel,
reqch: make(chan dialRequest),
}
go ds.dialWorker(p, actd.reqch)
ds.dials[p] = actd
}
// increase ref count before dropping mutex
actd.refCnt++
return actd, nil
}
// Dial initiates a dial to the given peer if there are none in progress
// then waits for the dial to that peer to complete.
func (ds *dialSync) Dial(ctx context.Context, p peer.ID) (*Conn, error) {
ad, err := ds.getActiveDial(p)
if err != nil {
return nil, err
}
defer func() {
ds.mutex.Lock()
defer ds.mutex.Unlock()
ad.refCnt--
if ad.refCnt == 0 {
ad.close()
delete(ds.dials, p)
}
}()
return ad.dial(ctx)
}

View File

@@ -0,0 +1,492 @@
package swarm
import (
"context"
"math"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
ma "github.com/multiformats/go-multiaddr"
)
// /////////////////////////////////////////////////////////////////////////////////
// lo and behold, The Dialer
// TODO explain how all this works
// ////////////////////////////////////////////////////////////////////////////////
// dialRequest is structure used to request dials to the peer associated with a
// worker loop
type dialRequest struct {
// ctx is the context that may be used for the request
// if another concurrent request is made, any of the concurrent request's ctx may be used for
// dials to the peer's addresses
// ctx for simultaneous connect requests have higher priority than normal requests
ctx context.Context
// resch is the channel used to send the response for this query
resch chan dialResponse
}
// dialResponse is the response sent to dialRequests on the request's resch channel
type dialResponse struct {
// conn is the connection to the peer on success
conn *Conn
// err is the error in dialing the peer
// nil on connection success
err error
}
// pendRequest is used to track progress on a dialRequest.
type pendRequest struct {
// req is the original dialRequest
req dialRequest
// err comprises errors of all failed dials
err *DialError
// addrs are the addresses on which we are waiting for pending dials
// At the time of creation addrs is initialised to all the addresses of the peer. On a failed dial,
// the addr is removed from the map and err is updated. On a successful dial, the dialRequest is
// completed and response is sent with the connection
addrs map[string]struct{}
}
// addrDial tracks dials to a particular multiaddress.
type addrDial struct {
// addr is the address dialed
addr ma.Multiaddr
// ctx is the context used for dialing the address
ctx context.Context
// conn is the established connection on success
conn *Conn
// err is the err on dialing the address
err error
// requests is the list of pendRequests interested in this dial
// the value in the slice is the request number assigned to this request by the dialWorker
requests []int
// dialed indicates whether we have triggered the dial to the address
dialed bool
// createdAt is the time this struct was created
createdAt time.Time
// dialRankingDelay is the delay in dialing this address introduced by the ranking logic
dialRankingDelay time.Duration
}
// dialWorker synchronises concurrent dials to a peer. It ensures that we make at most one dial to a
// peer's address
type dialWorker struct {
s *Swarm
peer peer.ID
// reqch is used to send dial requests to the worker. close reqch to end the worker loop
reqch <-chan dialRequest
// reqno is the request number used to track different dialRequests for a peer.
// Each incoming request is assigned a reqno. This reqno is used in pendingRequests and in
// addrDial objects in trackedDials to track this request
reqno int
// pendingRequests maps reqno to the pendRequest object for a dialRequest
pendingRequests map[int]*pendRequest
// trackedDials tracks dials to the peers addresses. An entry here is used to ensure that
// we dial an address at most once
trackedDials map[string]*addrDial
// resch is used to receive response for dials to the peers addresses.
resch chan dialResult
connected bool // true when a connection has been successfully established
// for testing
wg sync.WaitGroup
cl Clock
}
func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest, cl Clock) *dialWorker {
if cl == nil {
cl = RealClock{}
}
return &dialWorker{
s: s,
peer: p,
reqch: reqch,
pendingRequests: make(map[int]*pendRequest),
trackedDials: make(map[string]*addrDial),
resch: make(chan dialResult),
cl: cl,
}
}
// loop implements the core dial worker loop. Requests are received on w.reqch.
// The loop exits when w.reqch is closed.
func (w *dialWorker) loop() {
w.wg.Add(1)
defer w.wg.Done()
defer w.s.limiter.clearAllPeerDials(w.peer)
// dq is used to pace dials to different addresses of the peer
dq := newDialQueue()
// dialsInFlight is the number of dials in flight.
dialsInFlight := 0
startTime := w.cl.Now()
// dialTimer is the dialTimer used to trigger dials
dialTimer := w.cl.InstantTimer(startTime.Add(math.MaxInt64))
timerRunning := true
// scheduleNextDial updates timer for triggering the next dial
scheduleNextDial := func() {
if timerRunning && !dialTimer.Stop() {
<-dialTimer.Ch()
}
timerRunning = false
if dq.len() > 0 {
if dialsInFlight == 0 && !w.connected {
// if there are no dials in flight, trigger the next dials immediately
dialTimer.Reset(startTime)
} else {
dialTimer.Reset(startTime.Add(dq.top().Delay))
}
timerRunning = true
}
}
// totalDials is used to track number of dials made by this worker for metrics
totalDials := 0
loop:
for {
// The loop has three parts
// 1. Input requests are received on w.reqch. If a suitable connection is not available we create
// a pendRequest object to track the dialRequest and add the addresses to dq.
// 2. Addresses from the dialQueue are dialed at appropriate time intervals depending on delay logic.
// We are notified of the completion of these dials on w.resch.
// 3. Responses for dials are received on w.resch. On receiving a response, we updated the pendRequests
// interested in dials on this address.
select {
case req, ok := <-w.reqch:
if !ok {
if w.s.metricsTracer != nil {
w.s.metricsTracer.DialCompleted(w.connected, totalDials)
}
return
}
// We have received a new request. If we do not have a suitable connection,
// track this dialRequest with a pendRequest.
// Enqueue the peer's addresses relevant to this request in dq and
// track dials to the addresses relevant to this request.
c, err := w.s.bestAcceptableConnToPeer(req.ctx, w.peer)
if c != nil || err != nil {
req.resch <- dialResponse{conn: c, err: err}
continue loop
}
addrs, err := w.s.addrsForDial(req.ctx, w.peer)
if err != nil {
req.resch <- dialResponse{err: err}
continue loop
}
// get the delays to dial these addrs from the swarms dialRanker
simConnect, _, _ := network.GetSimultaneousConnect(req.ctx)
addrRanking := w.rankAddrs(addrs, simConnect)
addrDelay := make(map[string]time.Duration, len(addrRanking))
// create the pending request object
pr := &pendRequest{
req: req,
err: &DialError{Peer: w.peer},
addrs: make(map[string]struct{}, len(addrRanking)),
}
for _, adelay := range addrRanking {
pr.addrs[string(adelay.Addr.Bytes())] = struct{}{}
addrDelay[string(adelay.Addr.Bytes())] = adelay.Delay
}
// Check if dials to any of the addrs have completed already
// If they have errored, record the error in pr. If they have succeeded,
// respond with the connection.
// If they are pending, add them to tojoin.
// If we haven't seen any of the addresses before, add them to todial.
var todial []ma.Multiaddr
var tojoin []*addrDial
for _, adelay := range addrRanking {
ad, ok := w.trackedDials[string(adelay.Addr.Bytes())]
if !ok {
todial = append(todial, adelay.Addr)
continue
}
if ad.conn != nil {
// dial to this addr was successful, complete the request
req.resch <- dialResponse{conn: ad.conn}
continue loop
}
if ad.err != nil {
// dial to this addr errored, accumulate the error
pr.err.recordErr(ad.addr, ad.err)
delete(pr.addrs, string(ad.addr.Bytes()))
continue
}
// dial is still pending, add to the join list
tojoin = append(tojoin, ad)
}
if len(todial) == 0 && len(tojoin) == 0 {
// all request applicable addrs have been dialed, we must have errored
req.resch <- dialResponse{err: pr.err}
continue loop
}
// The request has some pending or new dials. We assign this request a request number.
// This value of w.reqno is used to track this request in all the structures
w.reqno++
w.pendingRequests[w.reqno] = pr
for _, ad := range tojoin {
if !ad.dialed {
// we haven't dialed this address. update the ad.ctx to have simultaneous connect values
// set correctly
if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect {
if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect {
ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason)
// update the element in dq to use the simultaneous connect delay.
dq.Add(network.AddrDelay{
Addr: ad.addr,
Delay: addrDelay[string(ad.addr.Bytes())],
})
}
}
}
// add the request to the addrDial
ad.requests = append(ad.requests, w.reqno)
}
if len(todial) > 0 {
now := time.Now()
// these are new addresses, track them and add them to dq
for _, a := range todial {
w.trackedDials[string(a.Bytes())] = &addrDial{
addr: a,
ctx: req.ctx,
requests: []int{w.reqno},
createdAt: now,
}
dq.Add(network.AddrDelay{Addr: a, Delay: addrDelay[string(a.Bytes())]})
}
}
// setup dialTimer for updates to dq
scheduleNextDial()
case <-dialTimer.Ch():
// It's time to dial the next batch of addresses.
// We don't check the delay of the addresses received from the queue here
// because if the timer triggered before the delay, it means that all
// the inflight dials have errored and we should dial the next batch of
// addresses
now := time.Now()
for _, adelay := range dq.NextBatch() {
// spawn the dial
ad, ok := w.trackedDials[string(adelay.Addr.Bytes())]
if !ok {
log.Errorf("SWARM BUG: no entry for address %s in trackedDials", adelay.Addr)
continue
}
ad.dialed = true
ad.dialRankingDelay = now.Sub(ad.createdAt)
err := w.s.dialNextAddr(ad.ctx, w.peer, ad.addr, w.resch)
if err != nil {
// Errored without attempting a dial. This happens in case of
// backoff or black hole.
w.dispatchError(ad, err)
} else {
// the dial was successful. update inflight dials
dialsInFlight++
totalDials++
}
}
timerRunning = false
// schedule more dials
scheduleNextDial()
case res := <-w.resch:
// A dial to an address has completed.
// Update all requests waiting on this address. On success, complete the request.
// On error, record the error
dialsInFlight--
ad, ok := w.trackedDials[string(res.Addr.Bytes())]
if !ok {
log.Errorf("SWARM BUG: no entry for address %s in trackedDials", res.Addr)
if res.Conn != nil {
res.Conn.Close()
}
continue
}
if res.Conn != nil {
// we got a connection, add it to the swarm
conn, err := w.s.addConn(res.Conn, network.DirOutbound)
if err != nil {
// oops no, we failed to add it to the swarm
res.Conn.Close()
w.dispatchError(ad, err)
continue loop
}
// request succeeded, respond to all pending requests
for _, reqno := range ad.requests {
pr, ok := w.pendingRequests[reqno]
if !ok {
// some other dial for this request succeeded before this one
continue
}
pr.req.resch <- dialResponse{conn: conn}
delete(w.pendingRequests, reqno)
}
ad.conn = conn
ad.requests = nil
if !w.connected {
w.connected = true
if w.s.metricsTracer != nil {
w.s.metricsTracer.DialRankingDelay(ad.dialRankingDelay)
}
}
continue loop
}
// it must be an error -- add backoff if applicable and dispatch
// ErrDialRefusedBlackHole shouldn't end up here, just a safety check
if res.Err != ErrDialRefusedBlackHole && res.Err != context.Canceled && !w.connected {
// we only add backoff if there has not been a successful connection
// for consistency with the old dialer behavior.
w.s.backf.AddBackoff(w.peer, res.Addr)
} else if res.Err == ErrDialRefusedBlackHole {
log.Errorf("SWARM BUG: unexpected ErrDialRefusedBlackHole while dialing peer %s to addr %s",
w.peer, res.Addr)
}
w.dispatchError(ad, res.Err)
// Only schedule next dial on error.
// If we scheduleNextDial on success, we will end up making one dial more than
// required because the final successful dial will spawn one more dial
scheduleNextDial()
}
}
}
// dispatches an error to a specific addr dial
func (w *dialWorker) dispatchError(ad *addrDial, err error) {
ad.err = err
for _, reqno := range ad.requests {
pr, ok := w.pendingRequests[reqno]
if !ok {
// some other dial for this request succeeded before this one
continue
}
// accumulate the error
pr.err.recordErr(ad.addr, err)
delete(pr.addrs, string(ad.addr.Bytes()))
if len(pr.addrs) == 0 {
// all addrs have erred, dispatch dial error
// but first do a last one check in case an acceptable connection has landed from
// a simultaneous dial that started later and added new acceptable addrs
c, _ := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer)
if c != nil {
pr.req.resch <- dialResponse{conn: c}
} else {
pr.req.resch <- dialResponse{err: pr.err}
}
delete(w.pendingRequests, reqno)
}
}
ad.requests = nil
// if it was a backoff, clear the address dial so that it doesn't inhibit new dial requests.
// this is necessary to support active listen scenarios, where a new dial comes in while
// another dial is in progress, and needs to do a direct connection without inhibitions from
// dial backoff.
if err == ErrDialBackoff {
delete(w.trackedDials, string(ad.addr.Bytes()))
}
}
// rankAddrs ranks addresses for dialing. if it's a simConnect request we
// dial all addresses immediately without any delay
func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr, isSimConnect bool) []network.AddrDelay {
if isSimConnect {
return NoDelayDialRanker(addrs)
}
return w.s.dialRanker(addrs)
}
// dialQueue is a priority queue used to schedule dials
type dialQueue struct {
// q contains dials ordered by delay
q []network.AddrDelay
}
// newDialQueue returns a new dialQueue
func newDialQueue() *dialQueue {
return &dialQueue{q: make([]network.AddrDelay, 0, 16)}
}
// Add adds adelay to the queue. If another element exists in the queue with
// the same address, it replaces that element.
func (dq *dialQueue) Add(adelay network.AddrDelay) {
for i := 0; i < dq.len(); i++ {
if dq.q[i].Addr.Equal(adelay.Addr) {
if dq.q[i].Delay == adelay.Delay {
// existing element is the same. nothing to do
return
}
// remove the element
copy(dq.q[i:], dq.q[i+1:])
dq.q = dq.q[:len(dq.q)-1]
break
}
}
for i := 0; i < dq.len(); i++ {
if dq.q[i].Delay > adelay.Delay {
dq.q = append(dq.q, network.AddrDelay{}) // extend the slice
copy(dq.q[i+1:], dq.q[i:])
dq.q[i] = adelay
return
}
}
dq.q = append(dq.q, adelay)
}
// NextBatch returns all the elements in the queue with the highest priority
func (dq *dialQueue) NextBatch() []network.AddrDelay {
if dq.len() == 0 {
return nil
}
// i is the index of the second highest priority element
var i int
for i = 0; i < dq.len(); i++ {
if dq.q[i].Delay != dq.q[0].Delay {
break
}
}
res := dq.q[:i]
dq.q = dq.q[i:]
return res
}
// top returns the top element of the queue
func (dq *dialQueue) top() network.AddrDelay {
return dq.q[0]
}
// len returns the number of elements in the queue
func (dq *dialQueue) len() int {
return len(dq.q)
}

View File

@@ -0,0 +1,227 @@
package swarm
import (
"context"
"os"
"strconv"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
)
type dialResult struct {
Conn transport.CapableConn
Addr ma.Multiaddr
Err error
}
type dialJob struct {
addr ma.Multiaddr
peer peer.ID
ctx context.Context
resp chan dialResult
timeout time.Duration
}
func (dj *dialJob) cancelled() bool {
return dj.ctx.Err() != nil
}
type dialLimiter struct {
lk sync.Mutex
fdConsuming int
fdLimit int
waitingOnFd []*dialJob
dialFunc dialfunc
activePerPeer map[peer.ID]int
perPeerLimit int
waitingOnPeerLimit map[peer.ID][]*dialJob
}
type dialfunc func(context.Context, peer.ID, ma.Multiaddr) (transport.CapableConn, error)
func newDialLimiter(df dialfunc) *dialLimiter {
fd := ConcurrentFdDials
if env := os.Getenv("LIBP2P_SWARM_FD_LIMIT"); env != "" {
if n, err := strconv.ParseInt(env, 10, 32); err == nil {
fd = int(n)
}
}
return newDialLimiterWithParams(df, fd, DefaultPerPeerRateLimit)
}
func newDialLimiterWithParams(df dialfunc, fdLimit, perPeerLimit int) *dialLimiter {
return &dialLimiter{
fdLimit: fdLimit,
perPeerLimit: perPeerLimit,
waitingOnPeerLimit: make(map[peer.ID][]*dialJob),
activePerPeer: make(map[peer.ID]int),
dialFunc: df,
}
}
// freeFDToken frees FD token and if there are any schedules another waiting dialJob
// in it's place
func (dl *dialLimiter) freeFDToken() {
log.Debugf("[limiter] freeing FD token; waiting: %d; consuming: %d", len(dl.waitingOnFd), dl.fdConsuming)
dl.fdConsuming--
for len(dl.waitingOnFd) > 0 {
next := dl.waitingOnFd[0]
dl.waitingOnFd[0] = nil // clear out memory
dl.waitingOnFd = dl.waitingOnFd[1:]
if len(dl.waitingOnFd) == 0 {
// clear out memory.
dl.waitingOnFd = nil
}
// Skip over canceled dials instead of queuing up a goroutine.
if next.cancelled() {
dl.freePeerToken(next)
continue
}
dl.fdConsuming++
// we already have activePerPeer token at this point so we can just dial
go dl.executeDial(next)
return
}
}
func (dl *dialLimiter) freePeerToken(dj *dialJob) {
log.Debugf("[limiter] freeing peer token; peer %s; addr: %s; active for peer: %d; waiting on peer limit: %d",
dj.peer, dj.addr, dl.activePerPeer[dj.peer], len(dl.waitingOnPeerLimit[dj.peer]))
// release tokens in reverse order than we take them
dl.activePerPeer[dj.peer]--
if dl.activePerPeer[dj.peer] == 0 {
delete(dl.activePerPeer, dj.peer)
}
waitlist := dl.waitingOnPeerLimit[dj.peer]
for len(waitlist) > 0 {
next := waitlist[0]
waitlist[0] = nil // clear out memory
waitlist = waitlist[1:]
if len(waitlist) == 0 {
delete(dl.waitingOnPeerLimit, next.peer)
} else {
dl.waitingOnPeerLimit[next.peer] = waitlist
}
if next.cancelled() {
continue
}
dl.activePerPeer[next.peer]++ // just kidding, we still want this token
dl.addCheckFdLimit(next)
return
}
}
func (dl *dialLimiter) finishedDial(dj *dialJob) {
dl.lk.Lock()
defer dl.lk.Unlock()
if dl.shouldConsumeFd(dj.addr) {
dl.freeFDToken()
}
dl.freePeerToken(dj)
}
func (dl *dialLimiter) shouldConsumeFd(addr ma.Multiaddr) bool {
// we don't consume FD's for relay addresses for now as they will be consumed when the Relay Transport
// actually dials the Relay server. That dial call will also pass through this limiter with
// the address of the relay server i.e. non-relay address.
_, err := addr.ValueForProtocol(ma.P_CIRCUIT)
isRelay := err == nil
return !isRelay && isFdConsumingAddr(addr)
}
func (dl *dialLimiter) addCheckFdLimit(dj *dialJob) {
if dl.shouldConsumeFd(dj.addr) {
if dl.fdConsuming >= dl.fdLimit {
log.Debugf("[limiter] blocked dial waiting on FD token; peer: %s; addr: %s; consuming: %d; "+
"limit: %d; waiting: %d", dj.peer, dj.addr, dl.fdConsuming, dl.fdLimit, len(dl.waitingOnFd))
dl.waitingOnFd = append(dl.waitingOnFd, dj)
return
}
log.Debugf("[limiter] taking FD token: peer: %s; addr: %s; prev consuming: %d",
dj.peer, dj.addr, dl.fdConsuming)
// take token
dl.fdConsuming++
}
log.Debugf("[limiter] executing dial; peer: %s; addr: %s; FD consuming: %d; waiting: %d",
dj.peer, dj.addr, dl.fdConsuming, len(dl.waitingOnFd))
go dl.executeDial(dj)
}
func (dl *dialLimiter) addCheckPeerLimit(dj *dialJob) {
if dl.activePerPeer[dj.peer] >= dl.perPeerLimit {
log.Debugf("[limiter] blocked dial waiting on peer limit; peer: %s; addr: %s; active: %d; "+
"peer limit: %d; waiting: %d", dj.peer, dj.addr, dl.activePerPeer[dj.peer], dl.perPeerLimit,
len(dl.waitingOnPeerLimit[dj.peer]))
wlist := dl.waitingOnPeerLimit[dj.peer]
dl.waitingOnPeerLimit[dj.peer] = append(wlist, dj)
return
}
dl.activePerPeer[dj.peer]++
dl.addCheckFdLimit(dj)
}
// AddDialJob tries to take the needed tokens for starting the given dial job.
// If it acquires all needed tokens, it immediately starts the dial, otherwise
// it will put it on the waitlist for the requested token.
func (dl *dialLimiter) AddDialJob(dj *dialJob) {
dl.lk.Lock()
defer dl.lk.Unlock()
log.Debugf("[limiter] adding a dial job through limiter: %v", dj.addr)
dl.addCheckPeerLimit(dj)
}
func (dl *dialLimiter) clearAllPeerDials(p peer.ID) {
dl.lk.Lock()
defer dl.lk.Unlock()
delete(dl.waitingOnPeerLimit, p)
log.Debugf("[limiter] clearing all peer dials: %v", p)
// NB: the waitingOnFd list doesn't need to be cleaned out here, we will
// remove them as we encounter them because they are 'cancelled' at this
// point
}
// executeDial calls the dialFunc, and reports the result through the response
// channel when finished. Once the response is sent it also releases all tokens
// it held during the dial.
func (dl *dialLimiter) executeDial(j *dialJob) {
defer dl.finishedDial(j)
if j.cancelled() {
return
}
dctx, cancel := context.WithTimeout(j.ctx, j.timeout)
defer cancel()
con, err := dl.dialFunc(dctx, j.peer, j.addr)
select {
case j.resp <- dialResult{Conn: con, Addr: j.addr, Err: err}:
case <-j.ctx.Done():
if con != nil {
con.Close()
}
}
}

View File

@@ -0,0 +1,754 @@
package swarm
import (
"context"
"errors"
"fmt"
"io"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/metrics"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/transport"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
)
const (
defaultDialTimeout = 15 * time.Second
// defaultDialTimeoutLocal is the maximum duration a Dial to local network address
// is allowed to take.
// This includes the time between dialing the raw network connection,
// protocol selection as well the handshake, if applicable.
defaultDialTimeoutLocal = 5 * time.Second
)
var log = logging.Logger("swarm2")
// ErrSwarmClosed is returned when one attempts to operate on a closed swarm.
var ErrSwarmClosed = errors.New("swarm closed")
// ErrAddrFiltered is returned when trying to register a connection to a
// filtered address. You shouldn't see this error unless some underlying
// transport is misbehaving.
var ErrAddrFiltered = errors.New("address filtered")
// ErrDialTimeout is returned when one a dial times out due to the global timeout
var ErrDialTimeout = errors.New("dial timed out")
type Option func(*Swarm) error
// WithConnectionGater sets a connection gater
func WithConnectionGater(gater connmgr.ConnectionGater) Option {
return func(s *Swarm) error {
s.gater = gater
return nil
}
}
// WithMultiaddrResolver sets a custom multiaddress resolver
func WithMultiaddrResolver(maResolver *madns.Resolver) Option {
return func(s *Swarm) error {
s.maResolver = maResolver
return nil
}
}
// WithMetrics sets a metrics reporter
func WithMetrics(reporter metrics.Reporter) Option {
return func(s *Swarm) error {
s.bwc = reporter
return nil
}
}
func WithMetricsTracer(t MetricsTracer) Option {
return func(s *Swarm) error {
s.metricsTracer = t
return nil
}
}
func WithDialTimeout(t time.Duration) Option {
return func(s *Swarm) error {
s.dialTimeout = t
return nil
}
}
func WithDialTimeoutLocal(t time.Duration) Option {
return func(s *Swarm) error {
s.dialTimeoutLocal = t
return nil
}
}
func WithResourceManager(m network.ResourceManager) Option {
return func(s *Swarm) error {
s.rcmgr = m
return nil
}
}
// WithDialRanker configures swarm to use d as the DialRanker
func WithDialRanker(d network.DialRanker) Option {
return func(s *Swarm) error {
if d == nil {
return errors.New("swarm: dial ranker cannot be nil")
}
s.dialRanker = d
return nil
}
}
// WithUDPBlackHoleConfig configures swarm to use c as the config for UDP black hole detection
// n is the size of the sliding window used to evaluate black hole state
// min is the minimum number of successes out of n required to not block requests
func WithUDPBlackHoleConfig(enabled bool, n, min int) Option {
return func(s *Swarm) error {
s.udpBlackHoleConfig = blackHoleConfig{Enabled: enabled, N: n, MinSuccesses: min}
return nil
}
}
// WithIPv6BlackHoleConfig configures swarm to use c as the config for IPv6 black hole detection
// n is the size of the sliding window used to evaluate black hole state
// min is the minimum number of successes out of n required to not block requests
func WithIPv6BlackHoleConfig(enabled bool, n, min int) Option {
return func(s *Swarm) error {
s.ipv6BlackHoleConfig = blackHoleConfig{Enabled: enabled, N: n, MinSuccesses: min}
return nil
}
}
// Swarm is a connection muxer, allowing connections to other peers to
// be opened and closed, while still using the same Chan for all
// communication. The Chan sends/receives Messages, which note the
// destination or source Peer.
type Swarm struct {
nextConnID uint64 // guarded by atomic
nextStreamID uint64 // guarded by atomic
// Close refcount. This allows us to fully wait for the swarm to be torn
// down before continuing.
refs sync.WaitGroup
emitter event.Emitter
rcmgr network.ResourceManager
local peer.ID
peers peerstore.Peerstore
dialTimeout time.Duration
dialTimeoutLocal time.Duration
conns struct {
sync.RWMutex
m map[peer.ID][]*Conn
}
listeners struct {
sync.RWMutex
ifaceListenAddres []ma.Multiaddr
cacheEOL time.Time
m map[transport.Listener]struct{}
}
notifs struct {
sync.RWMutex
m map[network.Notifiee]struct{}
}
transports struct {
sync.RWMutex
m map[int]transport.Transport
}
maResolver *madns.Resolver
// stream handlers
streamh atomic.Pointer[network.StreamHandler]
// dialing helpers
dsync *dialSync
backf DialBackoff
limiter *dialLimiter
gater connmgr.ConnectionGater
closeOnce sync.Once
ctx context.Context // is canceled when Close is called
ctxCancel context.CancelFunc
bwc metrics.Reporter
metricsTracer MetricsTracer
dialRanker network.DialRanker
udpBlackHoleConfig blackHoleConfig
ipv6BlackHoleConfig blackHoleConfig
bhd *blackHoleDetector
}
// NewSwarm constructs a Swarm.
func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts ...Option) (*Swarm, error) {
emitter, err := eventBus.Emitter(new(event.EvtPeerConnectednessChanged))
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
s := &Swarm{
local: local,
peers: peers,
emitter: emitter,
ctx: ctx,
ctxCancel: cancel,
dialTimeout: defaultDialTimeout,
dialTimeoutLocal: defaultDialTimeoutLocal,
maResolver: madns.DefaultResolver,
dialRanker: DefaultDialRanker,
// A black hole is a binary property. On a network if UDP dials are blocked or there is
// no IPv6 connectivity, all dials will fail. So a low success rate of 5 out 100 dials
// is good enough.
udpBlackHoleConfig: blackHoleConfig{Enabled: true, N: 100, MinSuccesses: 5},
ipv6BlackHoleConfig: blackHoleConfig{Enabled: true, N: 100, MinSuccesses: 5},
}
s.conns.m = make(map[peer.ID][]*Conn)
s.listeners.m = make(map[transport.Listener]struct{})
s.transports.m = make(map[int]transport.Transport)
s.notifs.m = make(map[network.Notifiee]struct{})
for _, opt := range opts {
if err := opt(s); err != nil {
return nil, err
}
}
if s.rcmgr == nil {
s.rcmgr = &network.NullResourceManager{}
}
s.dsync = newDialSync(s.dialWorkerLoop)
s.limiter = newDialLimiter(s.dialAddr)
s.backf.init(s.ctx)
s.bhd = newBlackHoleDetector(s.udpBlackHoleConfig, s.ipv6BlackHoleConfig, s.metricsTracer)
return s, nil
}
func (s *Swarm) Close() error {
s.closeOnce.Do(s.close)
return nil
}
func (s *Swarm) close() {
s.ctxCancel()
s.emitter.Close()
// Prevents new connections and/or listeners from being added to the swarm.
s.listeners.Lock()
listeners := s.listeners.m
s.listeners.m = nil
s.listeners.Unlock()
s.conns.Lock()
conns := s.conns.m
s.conns.m = nil
s.conns.Unlock()
// Lots of goroutines but we might as well do this in parallel. We want to shut down as fast as
// possible.
for l := range listeners {
go func(l transport.Listener) {
if err := l.Close(); err != nil && err != transport.ErrListenerClosed {
log.Errorf("error when shutting down listener: %s", err)
}
}(l)
}
for _, cs := range conns {
for _, c := range cs {
go func(c *Conn) {
if err := c.Close(); err != nil {
log.Errorf("error when shutting down connection: %s", err)
}
}(c)
}
}
// Wait for everything to finish.
s.refs.Wait()
// Now close out any transports (if necessary). Do this after closing
// all connections/listeners.
s.transports.Lock()
transports := s.transports.m
s.transports.m = nil
s.transports.Unlock()
// Dedup transports that may be listening on multiple protocols
transportsToClose := make(map[transport.Transport]struct{}, len(transports))
for _, t := range transports {
transportsToClose[t] = struct{}{}
}
var wg sync.WaitGroup
for t := range transportsToClose {
if closer, ok := t.(io.Closer); ok {
wg.Add(1)
go func(c io.Closer) {
defer wg.Done()
if err := closer.Close(); err != nil {
log.Errorf("error when closing down transport %T: %s", c, err)
}
}(closer)
}
}
wg.Wait()
}
func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, error) {
var (
p = tc.RemotePeer()
addr = tc.RemoteMultiaddr()
)
// create the Stat object, initializing with the underlying connection Stat if available
var stat network.ConnStats
if cs, ok := tc.(network.ConnStat); ok {
stat = cs.Stat()
}
stat.Direction = dir
stat.Opened = time.Now()
// Wrap and register the connection.
c := &Conn{
conn: tc,
swarm: s,
stat: stat,
id: atomic.AddUint64(&s.nextConnID, 1),
}
// we ONLY check upgraded connections here so we can send them a Disconnect message.
// If we do this in the Upgrader, we will not be able to do this.
if s.gater != nil {
if allow, _ := s.gater.InterceptUpgraded(c); !allow {
// TODO Send disconnect with reason here
err := tc.Close()
if err != nil {
log.Warnf("failed to close connection with peer %s and addr %s; err: %s", p.Pretty(), addr, err)
}
return nil, ErrGaterDisallowedConnection
}
}
// Add the public key.
if pk := tc.RemotePublicKey(); pk != nil {
s.peers.AddPubKey(p, pk)
}
// Clear any backoffs
s.backf.Clear(p)
// Finally, add the peer.
s.conns.Lock()
// Check if we're still online
if s.conns.m == nil {
s.conns.Unlock()
tc.Close()
return nil, ErrSwarmClosed
}
c.streams.m = make(map[*Stream]struct{})
isFirstConnection := len(s.conns.m[p]) == 0
s.conns.m[p] = append(s.conns.m[p], c)
// Add two swarm refs:
// * One will be decremented after the close notifications fire in Conn.doClose
// * The other will be decremented when Conn.start exits.
s.refs.Add(2)
// Take the notification lock before releasing the conns lock to block
// Disconnect notifications until after the Connect notifications done.
c.notifyLk.Lock()
s.conns.Unlock()
// Emit event after releasing `s.conns` lock so that a consumer can still
// use swarm methods that need the `s.conns` lock.
if isFirstConnection {
s.emitter.Emit(event.EvtPeerConnectednessChanged{
Peer: p,
Connectedness: network.Connected,
})
}
s.notifyAll(func(f network.Notifiee) {
f.Connected(s, c)
})
c.notifyLk.Unlock()
c.start()
return c, nil
}
// Peerstore returns this swarms internal Peerstore.
func (s *Swarm) Peerstore() peerstore.Peerstore {
return s.peers
}
// SetStreamHandler assigns the handler for new streams.
func (s *Swarm) SetStreamHandler(handler network.StreamHandler) {
s.streamh.Store(&handler)
}
// StreamHandler gets the handler for new streams.
func (s *Swarm) StreamHandler() network.StreamHandler {
handler := s.streamh.Load()
if handler == nil {
return nil
}
return *handler
}
// NewStream creates a new stream on any available connection to peer, dialing
// if necessary.
func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error) {
log.Debugf("[%s] opening stream to peer [%s]", s.local, p)
// Algorithm:
// 1. Find the best connection, otherwise, dial.
// 2. Try opening a stream.
// 3. If the underlying connection is, in fact, closed, close the outer
// connection and try again. We do this in case we have a closed
// connection but don't notice it until we actually try to open a
// stream.
//
// Note: We only dial once.
//
// TODO: Try all connections even if we get an error opening a stream on
// a non-closed connection.
dials := 0
for {
// will prefer direct connections over relayed connections for opening streams
c, err := s.bestAcceptableConnToPeer(ctx, p)
if err != nil {
return nil, err
}
if c == nil {
if nodial, _ := network.GetNoDial(ctx); nodial {
return nil, network.ErrNoConn
}
if dials >= DialAttempts {
return nil, errors.New("max dial attempts exceeded")
}
dials++
var err error
c, err = s.dialPeer(ctx, p)
if err != nil {
return nil, err
}
}
s, err := c.NewStream(ctx)
if err != nil {
if c.conn.IsClosed() {
continue
}
return nil, err
}
return s, nil
}
}
// ConnsToPeer returns all the live connections to peer.
func (s *Swarm) ConnsToPeer(p peer.ID) []network.Conn {
// TODO: Consider sorting the connection list best to worst. Currently,
// it's sorted oldest to newest.
s.conns.RLock()
defer s.conns.RUnlock()
conns := s.conns.m[p]
output := make([]network.Conn, len(conns))
for i, c := range conns {
output[i] = c
}
return output
}
func isBetterConn(a, b *Conn) bool {
// If one is transient and not the other, prefer the non-transient connection.
aTransient := a.Stat().Transient
bTransient := b.Stat().Transient
if aTransient != bTransient {
return !aTransient
}
// If one is direct and not the other, prefer the direct connection.
aDirect := isDirectConn(a)
bDirect := isDirectConn(b)
if aDirect != bDirect {
return aDirect
}
// Otherwise, prefer the connection with more open streams.
a.streams.Lock()
aLen := len(a.streams.m)
a.streams.Unlock()
b.streams.Lock()
bLen := len(b.streams.m)
b.streams.Unlock()
if aLen != bLen {
return aLen > bLen
}
// finally, pick the last connection.
return true
}
// bestConnToPeer returns the best connection to peer.
func (s *Swarm) bestConnToPeer(p peer.ID) *Conn {
// TODO: Prefer some transports over others.
// For now, prefers direct connections over Relayed connections.
// For tie-breaking, select the newest non-closed connection with the most streams.
s.conns.RLock()
defer s.conns.RUnlock()
var best *Conn
for _, c := range s.conns.m[p] {
if c.conn.IsClosed() {
// We *will* garbage collect this soon anyways.
continue
}
if best == nil || isBetterConn(c, best) {
best = c
}
}
return best
}
// - Returns the best "acceptable" connection, if available.
// - Returns nothing if no such connection exists, but if we should try dialing anyways.
// - Returns an error if no such connection exists, but we should not try dialing.
func (s *Swarm) bestAcceptableConnToPeer(ctx context.Context, p peer.ID) (*Conn, error) {
conn := s.bestConnToPeer(p)
if conn == nil {
return nil, nil
}
forceDirect, _ := network.GetForceDirectDial(ctx)
if forceDirect && !isDirectConn(conn) {
return nil, nil
}
useTransient, _ := network.GetUseTransient(ctx)
if useTransient || !conn.Stat().Transient {
return conn, nil
}
return nil, network.ErrTransientConn
}
func isDirectConn(c *Conn) bool {
return c != nil && !c.conn.Transport().Proxy()
}
// Connectedness returns our "connectedness" state with the given peer.
//
// To check if we have an open connection, use `s.Connectedness(p) ==
// network.Connected`.
func (s *Swarm) Connectedness(p peer.ID) network.Connectedness {
if s.bestConnToPeer(p) != nil {
return network.Connected
}
return network.NotConnected
}
// Conns returns a slice of all connections.
func (s *Swarm) Conns() []network.Conn {
s.conns.RLock()
defer s.conns.RUnlock()
conns := make([]network.Conn, 0, len(s.conns.m))
for _, cs := range s.conns.m {
for _, c := range cs {
conns = append(conns, c)
}
}
return conns
}
// ClosePeer closes all connections to the given peer.
func (s *Swarm) ClosePeer(p peer.ID) error {
conns := s.ConnsToPeer(p)
switch len(conns) {
case 0:
return nil
case 1:
return conns[0].Close()
default:
errCh := make(chan error)
for _, c := range conns {
go func(c network.Conn) {
errCh <- c.Close()
}(c)
}
var errs []string
for range conns {
err := <-errCh
if err != nil {
errs = append(errs, err.Error())
}
}
if len(errs) > 0 {
return fmt.Errorf("when disconnecting from peer %s: %s", p, strings.Join(errs, ", "))
}
return nil
}
}
// Peers returns a copy of the set of peers swarm is connected to.
func (s *Swarm) Peers() []peer.ID {
s.conns.RLock()
defer s.conns.RUnlock()
peers := make([]peer.ID, 0, len(s.conns.m))
for p := range s.conns.m {
peers = append(peers, p)
}
return peers
}
// LocalPeer returns the local peer swarm is associated to.
func (s *Swarm) LocalPeer() peer.ID {
return s.local
}
// Backoff returns the DialBackoff object for this swarm.
func (s *Swarm) Backoff() *DialBackoff {
return &s.backf
}
// notifyAll sends a signal to all Notifiees
func (s *Swarm) notifyAll(notify func(network.Notifiee)) {
s.notifs.RLock()
for f := range s.notifs.m {
notify(f)
}
s.notifs.RUnlock()
}
// Notify signs up Notifiee to receive signals when events happen
func (s *Swarm) Notify(f network.Notifiee) {
s.notifs.Lock()
s.notifs.m[f] = struct{}{}
s.notifs.Unlock()
}
// StopNotify unregisters Notifiee fromr receiving signals
func (s *Swarm) StopNotify(f network.Notifiee) {
s.notifs.Lock()
delete(s.notifs.m, f)
s.notifs.Unlock()
}
func (s *Swarm) removeConn(c *Conn) {
p := c.RemotePeer()
s.conns.Lock()
cs := s.conns.m[p]
if len(cs) == 1 {
delete(s.conns.m, p)
s.conns.Unlock()
// Emit event after releasing `s.conns` lock so that a consumer can still
// use swarm methods that need the `s.conns` lock.
s.emitter.Emit(event.EvtPeerConnectednessChanged{
Peer: p,
Connectedness: network.NotConnected,
})
return
}
defer s.conns.Unlock()
for i, ci := range cs {
if ci == c {
// NOTE: We're intentionally preserving order.
// This way, connections to a peer are always
// sorted oldest to newest.
copy(cs[i:], cs[i+1:])
cs[len(cs)-1] = nil
s.conns.m[p] = cs[:len(cs)-1]
break
}
}
}
// String returns a string representation of Network.
func (s *Swarm) String() string {
return fmt.Sprintf("<Swarm %s>", s.LocalPeer())
}
func (s *Swarm) ResourceManager() network.ResourceManager {
return s.rcmgr
}
// Swarm is a Network.
var _ network.Network = (*Swarm)(nil)
var _ transport.TransportNetwork = (*Swarm)(nil)
type connWithMetrics struct {
transport.CapableConn
opened time.Time
dir network.Direction
metricsTracer MetricsTracer
}
func wrapWithMetrics(capableConn transport.CapableConn, metricsTracer MetricsTracer, opened time.Time, dir network.Direction) connWithMetrics {
c := connWithMetrics{CapableConn: capableConn, opened: opened, dir: dir, metricsTracer: metricsTracer}
c.metricsTracer.OpenedConnection(c.dir, capableConn.RemotePublicKey(), capableConn.ConnState(), capableConn.LocalMultiaddr())
return c
}
func (c connWithMetrics) completedHandshake() {
c.metricsTracer.CompletedHandshake(time.Since(c.opened), c.ConnState(), c.LocalMultiaddr())
}
func (c connWithMetrics) Close() error {
c.metricsTracer.ClosedConnection(c.dir, time.Since(c.opened), c.ConnState(), c.LocalMultiaddr())
return c.CapableConn.Close()
}
func (c connWithMetrics) Stat() network.ConnStats {
if cs, ok := c.CapableConn.(network.ConnStat); ok {
return cs.Stat()
}
return network.ConnStats{}
}
var _ network.ConnStat = connWithMetrics{}

View File

@@ -0,0 +1,72 @@
package swarm
import (
"time"
manet "github.com/multiformats/go-multiaddr/net"
ma "github.com/multiformats/go-multiaddr"
)
// ListenAddresses returns a list of addresses at which this swarm listens.
func (s *Swarm) ListenAddresses() []ma.Multiaddr {
s.listeners.RLock()
defer s.listeners.RUnlock()
return s.listenAddressesNoLock()
}
func (s *Swarm) listenAddressesNoLock() []ma.Multiaddr {
addrs := make([]ma.Multiaddr, 0, len(s.listeners.m)+10) // A bit extra so we may avoid an extra allocation in the for loop below.
for l := range s.listeners.m {
addrs = append(addrs, l.Multiaddr())
}
return addrs
}
const ifaceAddrsCacheDuration = 1 * time.Minute
// InterfaceListenAddresses returns a list of addresses at which this swarm
// listens. It expands "any interface" addresses (/ip4/0.0.0.0, /ip6/::) to
// use the known local interfaces.
func (s *Swarm) InterfaceListenAddresses() ([]ma.Multiaddr, error) {
s.listeners.RLock() // RLock start
ifaceListenAddres := s.listeners.ifaceListenAddres
isEOL := time.Now().After(s.listeners.cacheEOL)
s.listeners.RUnlock() // RLock end
if !isEOL {
// Cache is valid, clone the slice
return append(ifaceListenAddres[:0:0], ifaceListenAddres...), nil
}
// Cache is not valid
// Perfrom double checked locking
s.listeners.Lock() // Lock start
ifaceListenAddres = s.listeners.ifaceListenAddres
isEOL = time.Now().After(s.listeners.cacheEOL)
if isEOL {
// Cache is still invalid
listenAddres := s.listenAddressesNoLock()
if len(listenAddres) > 0 {
// We're actually listening on addresses.
var err error
ifaceListenAddres, err = manet.ResolveUnspecifiedAddresses(listenAddres, nil)
if err != nil {
s.listeners.Unlock() // Lock early exit
return nil, err
}
} else {
ifaceListenAddres = nil
}
s.listeners.ifaceListenAddres = ifaceListenAddres
s.listeners.cacheEOL = time.Now().Add(ifaceAddrsCacheDuration)
}
s.listeners.Unlock() // Lock end
return append(ifaceListenAddres[:0:0], ifaceListenAddres...), nil
}

View File

@@ -0,0 +1,267 @@
package swarm
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
)
// TODO: Put this elsewhere.
// ErrConnClosed is returned when operating on a closed connection.
var ErrConnClosed = errors.New("connection closed")
// Conn is the connection type used by swarm. In general, you won't use this
// type directly.
type Conn struct {
id uint64
conn transport.CapableConn
swarm *Swarm
closeOnce sync.Once
err error
notifyLk sync.Mutex
streams struct {
sync.Mutex
m map[*Stream]struct{}
}
stat network.ConnStats
}
var _ network.Conn = &Conn{}
func (c *Conn) IsClosed() bool {
return c.conn.IsClosed()
}
func (c *Conn) ID() string {
// format: <first 10 chars of peer id>-<global conn ordinal>
return fmt.Sprintf("%s-%d", c.RemotePeer().Pretty()[0:10], c.id)
}
// Close closes this connection.
//
// Note: This method won't wait for the close notifications to finish as that
// would create a deadlock when called from an open notification (because all
// open notifications must finish before we can fire off the close
// notifications).
func (c *Conn) Close() error {
c.closeOnce.Do(c.doClose)
return c.err
}
func (c *Conn) doClose() {
c.swarm.removeConn(c)
// Prevent new streams from opening.
c.streams.Lock()
streams := c.streams.m
c.streams.m = nil
c.streams.Unlock()
c.err = c.conn.Close()
// This is just for cleaning up state. The connection has already been closed.
// We *could* optimize this but it really isn't worth it.
for s := range streams {
s.Reset()
}
// do this in a goroutine to avoid deadlocking if we call close in an open notification.
go func() {
// prevents us from issuing close notifications before finishing the open notifications
c.notifyLk.Lock()
defer c.notifyLk.Unlock()
c.swarm.notifyAll(func(f network.Notifiee) {
f.Disconnected(c.swarm, c)
})
c.swarm.refs.Done() // taken in Swarm.addConn
}()
}
func (c *Conn) removeStream(s *Stream) {
c.streams.Lock()
c.stat.NumStreams--
delete(c.streams.m, s)
c.streams.Unlock()
s.scope.Done()
}
// listens for new streams.
//
// The caller must take a swarm ref before calling. This function decrements the
// swarm ref count.
func (c *Conn) start() {
go func() {
defer c.swarm.refs.Done()
defer c.Close()
for {
ts, err := c.conn.AcceptStream()
if err != nil {
return
}
scope, err := c.swarm.ResourceManager().OpenStream(c.RemotePeer(), network.DirInbound)
if err != nil {
ts.Reset()
continue
}
c.swarm.refs.Add(1)
go func() {
s, err := c.addStream(ts, network.DirInbound, scope)
// Don't defer this. We don't want to block
// swarm shutdown on the connection handler.
c.swarm.refs.Done()
// We only get an error here when the swarm is closed or closing.
if err != nil {
scope.Done()
return
}
if h := c.swarm.StreamHandler(); h != nil {
h(s)
}
}()
}
}()
}
func (c *Conn) String() string {
return fmt.Sprintf(
"<swarm.Conn[%T] %s (%s) <-> %s (%s)>",
c.conn.Transport(),
c.conn.LocalMultiaddr(),
c.conn.LocalPeer().Pretty(),
c.conn.RemoteMultiaddr(),
c.conn.RemotePeer().Pretty(),
)
}
// LocalMultiaddr is the Multiaddr on this side
func (c *Conn) LocalMultiaddr() ma.Multiaddr {
return c.conn.LocalMultiaddr()
}
// LocalPeer is the Peer on our side of the connection
func (c *Conn) LocalPeer() peer.ID {
return c.conn.LocalPeer()
}
// RemoteMultiaddr is the Multiaddr on the remote side
func (c *Conn) RemoteMultiaddr() ma.Multiaddr {
return c.conn.RemoteMultiaddr()
}
// RemotePeer is the Peer on the remote side
func (c *Conn) RemotePeer() peer.ID {
return c.conn.RemotePeer()
}
// RemotePublicKey is the public key of the peer on the remote side
func (c *Conn) RemotePublicKey() ic.PubKey {
return c.conn.RemotePublicKey()
}
// ConnState is the security connection state. including early data result.
// Empty if not supported.
func (c *Conn) ConnState() network.ConnectionState {
return c.conn.ConnState()
}
// Stat returns metadata pertaining to this connection
func (c *Conn) Stat() network.ConnStats {
c.streams.Lock()
defer c.streams.Unlock()
return c.stat
}
// NewStream returns a new Stream from this connection
func (c *Conn) NewStream(ctx context.Context) (network.Stream, error) {
if c.Stat().Transient {
if useTransient, _ := network.GetUseTransient(ctx); !useTransient {
return nil, network.ErrTransientConn
}
}
scope, err := c.swarm.ResourceManager().OpenStream(c.RemotePeer(), network.DirOutbound)
if err != nil {
return nil, err
}
s, err := c.openAndAddStream(ctx, scope)
if err != nil {
scope.Done()
return nil, err
}
return s, nil
}
func (c *Conn) openAndAddStream(ctx context.Context, scope network.StreamManagementScope) (network.Stream, error) {
ts, err := c.conn.OpenStream(ctx)
if err != nil {
return nil, err
}
return c.addStream(ts, network.DirOutbound, scope)
}
func (c *Conn) addStream(ts network.MuxedStream, dir network.Direction, scope network.StreamManagementScope) (*Stream, error) {
c.streams.Lock()
// Are we still online?
if c.streams.m == nil {
c.streams.Unlock()
ts.Reset()
return nil, ErrConnClosed
}
// Wrap and register the stream.
s := &Stream{
stream: ts,
conn: c,
scope: scope,
stat: network.Stats{
Direction: dir,
Opened: time.Now(),
},
id: atomic.AddUint64(&c.swarm.nextStreamID, 1),
}
c.stat.NumStreams++
c.streams.m[s] = struct{}{}
// Released once the stream disconnect notifications have finished
// firing (in Swarm.remove).
c.swarm.refs.Add(1)
c.streams.Unlock()
return s, nil
}
// GetStreams returns the streams associated with this connection.
func (c *Conn) GetStreams() []network.Stream {
c.streams.Lock()
defer c.streams.Unlock()
streams := make([]network.Stream, 0, len(c.streams.m))
for s := range c.streams.m {
streams = append(streams, s)
}
return streams
}
func (c *Conn) Scope() network.ConnScope {
return c.conn.Scope()
}

View File

@@ -0,0 +1,626 @@
package swarm
import (
"context"
"errors"
"fmt"
"net/netip"
"strconv"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/canonicallog"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
manet "github.com/multiformats/go-multiaddr/net"
)
// The maximum number of address resolution steps we'll perform for a single
// peer (for all addresses).
const maxAddressResolution = 32
// Diagram of dial sync:
//
// many callers of Dial() synched w. dials many addrs results to callers
// ----------------------\ dialsync use earliest /--------------
// -----------------------\ |----------\ /----------------
// ------------------------>------------<------- >---------<-----------------
// -----------------------| \----x \----------------
// ----------------------| \-----x \---------------
// any may fail if no addr at end
// retry dialAttempt x
var (
// ErrDialBackoff is returned by the backoff code when a given peer has
// been dialed too frequently
ErrDialBackoff = errors.New("dial backoff")
// ErrDialRefusedBlackHole is returned when we are in a black holed environment
ErrDialRefusedBlackHole = errors.New("dial refused because of black hole")
// ErrDialToSelf is returned if we attempt to dial our own peer
ErrDialToSelf = errors.New("dial to self attempted")
// ErrNoTransport is returned when we don't know a transport for the
// given multiaddr.
ErrNoTransport = errors.New("no transport for protocol")
// ErrAllDialsFailed is returned when connecting to a peer has ultimately failed
ErrAllDialsFailed = errors.New("all dials failed")
// ErrNoAddresses is returned when we fail to find any addresses for a
// peer we're trying to dial.
ErrNoAddresses = errors.New("no addresses")
// ErrNoGoodAddresses is returned when we find addresses for a peer but
// can't use any of them.
ErrNoGoodAddresses = errors.New("no good addresses")
// ErrGaterDisallowedConnection is returned when the gater prevents us from
// forming a connection with a peer.
ErrGaterDisallowedConnection = errors.New("gater disallows connection to peer")
)
// DialAttempts governs how many times a goroutine will try to dial a given peer.
// Note: this is down to one, as we have _too many dials_ atm. To add back in,
// add loop back in Dial(.)
const DialAttempts = 1
// ConcurrentFdDials is the number of concurrent outbound dials over transports
// that consume file descriptors
const ConcurrentFdDials = 160
// DefaultPerPeerRateLimit is the number of concurrent outbound dials to make
// per peer
var DefaultPerPeerRateLimit = 8
// DialBackoff is a type for tracking peer dial backoffs. Dialbackoff is used to
// avoid over-dialing the same, dead peers. Whenever we totally time out on all
// addresses of a peer, we add the addresses to DialBackoff. Then, whenever we
// attempt to dial the peer again, we check each address for backoff. If it's on
// backoff, we don't dial the address and exit promptly. If a dial is
// successful, the peer and all its addresses are removed from backoff.
//
// * It's safe to use its zero value.
// * It's thread-safe.
// * It's *not* safe to move this type after using.
type DialBackoff struct {
entries map[peer.ID]map[string]*backoffAddr
lock sync.RWMutex
}
type backoffAddr struct {
tries int
until time.Time
}
func (db *DialBackoff) init(ctx context.Context) {
if db.entries == nil {
db.entries = make(map[peer.ID]map[string]*backoffAddr)
}
go db.background(ctx)
}
func (db *DialBackoff) background(ctx context.Context) {
ticker := time.NewTicker(BackoffMax)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
db.cleanup()
}
}
}
// Backoff returns whether the client should backoff from dialing
// peer p at address addr
func (db *DialBackoff) Backoff(p peer.ID, addr ma.Multiaddr) (backoff bool) {
db.lock.RLock()
defer db.lock.RUnlock()
ap, found := db.entries[p][string(addr.Bytes())]
return found && time.Now().Before(ap.until)
}
// BackoffBase is the base amount of time to backoff (default: 5s).
var BackoffBase = time.Second * 5
// BackoffCoef is the backoff coefficient (default: 1s).
var BackoffCoef = time.Second
// BackoffMax is the maximum backoff time (default: 5m).
var BackoffMax = time.Minute * 5
// AddBackoff adds peer's address to backoff.
//
// Backoff is not exponential, it's quadratic and computed according to the
// following formula:
//
// BackoffBase + BakoffCoef * PriorBackoffs^2
//
// Where PriorBackoffs is the number of previous backoffs.
func (db *DialBackoff) AddBackoff(p peer.ID, addr ma.Multiaddr) {
saddr := string(addr.Bytes())
db.lock.Lock()
defer db.lock.Unlock()
bp, ok := db.entries[p]
if !ok {
bp = make(map[string]*backoffAddr, 1)
db.entries[p] = bp
}
ba, ok := bp[saddr]
if !ok {
bp[saddr] = &backoffAddr{
tries: 1,
until: time.Now().Add(BackoffBase),
}
return
}
backoffTime := BackoffBase + BackoffCoef*time.Duration(ba.tries*ba.tries)
if backoffTime > BackoffMax {
backoffTime = BackoffMax
}
ba.until = time.Now().Add(backoffTime)
ba.tries++
}
// Clear removes a backoff record. Clients should call this after a
// successful Dial.
func (db *DialBackoff) Clear(p peer.ID) {
db.lock.Lock()
defer db.lock.Unlock()
delete(db.entries, p)
}
func (db *DialBackoff) cleanup() {
db.lock.Lock()
defer db.lock.Unlock()
now := time.Now()
for p, e := range db.entries {
good := false
for _, backoff := range e {
backoffTime := BackoffBase + BackoffCoef*time.Duration(backoff.tries*backoff.tries)
if backoffTime > BackoffMax {
backoffTime = BackoffMax
}
if now.Before(backoff.until.Add(backoffTime)) {
good = true
break
}
}
if !good {
delete(db.entries, p)
}
}
}
// DialPeer connects to a peer.
//
// The idea is that the client of Swarm does not need to know what network
// the connection will happen over. Swarm can use whichever it choses.
// This allows us to use various transport protocols, do NAT traversal/relay,
// etc. to achieve connection.
func (s *Swarm) DialPeer(ctx context.Context, p peer.ID) (network.Conn, error) {
// Avoid typed nil issues.
c, err := s.dialPeer(ctx, p)
if err != nil {
return nil, err
}
return c, nil
}
// internal dial method that returns an unwrapped conn
//
// It is gated by the swarm's dial synchronization systems: dialsync and
// dialbackoff.
func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) {
log.Debugw("dialing peer", "from", s.local, "to", p)
err := p.Validate()
if err != nil {
return nil, err
}
if p == s.local {
return nil, ErrDialToSelf
}
// check if we already have an open (usable) connection first, or can't have a usable
// connection.
conn, err := s.bestAcceptableConnToPeer(ctx, p)
if conn != nil || err != nil {
return conn, err
}
if s.gater != nil && !s.gater.InterceptPeerDial(p) {
log.Debugf("gater disallowed outbound connection to peer %s", p.Pretty())
return nil, &DialError{Peer: p, Cause: ErrGaterDisallowedConnection}
}
// apply the DialPeer timeout
ctx, cancel := context.WithTimeout(ctx, network.GetDialPeerTimeout(ctx))
defer cancel()
conn, err = s.dsync.Dial(ctx, p)
if err == nil {
// Ensure we connected to the correct peer.
// This was most likely already checked by the security protocol, but it doesn't hurt do it again here.
if conn.RemotePeer() != p {
conn.Close()
log.Errorw("Handshake failed to properly authenticate peer", "authenticated", conn.RemotePeer(), "expected", p)
return nil, fmt.Errorf("unexpected peer")
}
return conn, nil
}
log.Debugf("network for %s finished dialing %s", s.local, p)
if ctx.Err() != nil {
// Context error trumps any dial errors as it was likely the ultimate cause.
return nil, ctx.Err()
}
if s.ctx.Err() != nil {
// Ok, so the swarm is shutting down.
return nil, ErrSwarmClosed
}
return nil, err
}
// dialWorkerLoop synchronizes and executes concurrent dials to a single peer
func (s *Swarm) dialWorkerLoop(p peer.ID, reqch <-chan dialRequest) {
w := newDialWorker(s, p, reqch, nil)
w.loop()
}
func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, error) {
peerAddrs := s.peers.Addrs(p)
if len(peerAddrs) == 0 {
return nil, ErrNoAddresses
}
peerAddrsAfterTransportResolved := make([]ma.Multiaddr, 0, len(peerAddrs))
for _, a := range peerAddrs {
tpt := s.TransportForDialing(a)
resolver, ok := tpt.(transport.Resolver)
if ok {
resolvedAddrs, err := resolver.Resolve(ctx, a)
if err != nil {
log.Warnf("Failed to resolve multiaddr %s by transport %v: %v", a, tpt, err)
continue
}
peerAddrsAfterTransportResolved = append(peerAddrsAfterTransportResolved, resolvedAddrs...)
} else {
peerAddrsAfterTransportResolved = append(peerAddrsAfterTransportResolved, a)
}
}
// Resolve dns or dnsaddrs
resolved, err := s.resolveAddrs(ctx, peer.AddrInfo{
ID: p,
Addrs: peerAddrsAfterTransportResolved,
})
if err != nil {
return nil, err
}
goodAddrs := s.filterKnownUndialables(p, resolved)
if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect {
goodAddrs = ma.FilterAddrs(goodAddrs, s.nonProxyAddr)
}
goodAddrs = ma.Unique(goodAddrs)
if len(goodAddrs) == 0 {
return nil, ErrNoGoodAddresses
}
s.peers.AddAddrs(p, goodAddrs, peerstore.TempAddrTTL)
return goodAddrs, nil
}
func (s *Swarm) resolveAddrs(ctx context.Context, pi peer.AddrInfo) ([]ma.Multiaddr, error) {
proto := ma.ProtocolWithCode(ma.P_P2P).Name
p2paddr, err := ma.NewMultiaddr("/" + proto + "/" + pi.ID.Pretty())
if err != nil {
return nil, err
}
resolveSteps := 0
// Recursively resolve all addrs.
//
// While the toResolve list is non-empty:
// * Pop an address off.
// * If the address is fully resolved, add it to the resolved list.
// * Otherwise, resolve it and add the results to the "to resolve" list.
toResolve := append(([]ma.Multiaddr)(nil), pi.Addrs...)
resolved := make([]ma.Multiaddr, 0, len(pi.Addrs))
for len(toResolve) > 0 {
// pop the last addr off.
addr := toResolve[len(toResolve)-1]
toResolve = toResolve[:len(toResolve)-1]
// if it's resolved, add it to the resolved list.
if !madns.Matches(addr) {
resolved = append(resolved, addr)
continue
}
resolveSteps++
// We've resolved too many addresses. We can keep all the fully
// resolved addresses but we'll need to skip the rest.
if resolveSteps >= maxAddressResolution {
log.Warnf(
"peer %s asked us to resolve too many addresses: %s/%s",
pi.ID,
resolveSteps,
maxAddressResolution,
)
continue
}
// otherwise, resolve it
reqaddr := addr.Encapsulate(p2paddr)
resaddrs, err := s.maResolver.Resolve(ctx, reqaddr)
if err != nil {
log.Infof("error resolving %s: %s", reqaddr, err)
}
// add the results to the toResolve list.
for _, res := range resaddrs {
pi, err := peer.AddrInfoFromP2pAddr(res)
if err != nil {
log.Infof("error parsing %s: %s", res, err)
}
toResolve = append(toResolve, pi.Addrs...)
}
}
return resolved, nil
}
func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, resch chan dialResult) error {
// check the dial backoff
if forceDirect, _ := network.GetForceDirectDial(ctx); !forceDirect {
if s.backf.Backoff(p, addr) {
return ErrDialBackoff
}
}
// start the dial
s.limitedDial(ctx, p, addr, resch)
return nil
}
func (s *Swarm) canDial(addr ma.Multiaddr) bool {
t := s.TransportForDialing(addr)
return t != nil && t.CanDial(addr)
}
func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool {
t := s.TransportForDialing(addr)
return !t.Proxy()
}
// filterKnownUndialables takes a list of multiaddrs, and removes those
// that we definitely don't want to dial: addresses configured to be blocked,
// IPv6 link-local addresses, addresses without a dial-capable transport,
// addresses that we know to be our own, and addresses with a better tranport
// available. This is an optimization to avoid wasting time on dials that we
// know are going to fail or for which we have a better alternative.
func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Multiaddr {
lisAddrs, _ := s.InterfaceListenAddresses()
var ourAddrs []ma.Multiaddr
for _, addr := range lisAddrs {
// we're only sure about filtering out /ip4 and /ip6 addresses, so far
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_IP4 || c.Protocol().Code == ma.P_IP6 {
ourAddrs = append(ourAddrs, addr)
}
return false
})
}
// The order of these two filters is important. If we can only dial /webtransport,
// we don't want to filter /webtransport addresses out because the peer had a /quic-v1
// address
// filter addresses we cannot dial
addrs = ma.FilterAddrs(addrs, s.canDial)
// filter low priority addresses among the addresses we can dial
addrs = filterLowPriorityAddresses(addrs)
// remove black holed addrs
addrs = s.bhd.FilterAddrs(addrs)
return ma.FilterAddrs(addrs,
func(addr ma.Multiaddr) bool { return !ma.Contains(ourAddrs, addr) },
// TODO: Consider allowing link-local addresses
func(addr ma.Multiaddr) bool { return !manet.IsIP6LinkLocal(addr) },
func(addr ma.Multiaddr) bool {
return s.gater == nil || s.gater.InterceptAddrDial(p, addr)
},
)
}
// limitedDial will start a dial to the given peer when
// it is able, respecting the various different types of rate
// limiting that occur without using extra goroutines per addr
func (s *Swarm) limitedDial(ctx context.Context, p peer.ID, a ma.Multiaddr, resp chan dialResult) {
timeout := s.dialTimeout
if lowTimeoutFilters.AddrBlocked(a) && s.dialTimeoutLocal < s.dialTimeout {
timeout = s.dialTimeoutLocal
}
s.limiter.AddDialJob(&dialJob{
addr: a,
peer: p,
resp: resp,
ctx: ctx,
timeout: timeout,
})
}
// dialAddr is the actual dial for an addr, indirectly invoked through the limiter
func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr) (transport.CapableConn, error) {
// Just to double check. Costs nothing.
if s.local == p {
return nil, ErrDialToSelf
}
// Check before we start work
if err := ctx.Err(); err != nil {
log.Debugf("%s swarm not dialing. Context cancelled: %v. %s %s", s.local, err, p, addr)
return nil, err
}
log.Debugf("%s swarm dialing %s %s", s.local, p, addr)
tpt := s.TransportForDialing(addr)
if tpt == nil {
return nil, ErrNoTransport
}
start := time.Now()
connC, err := tpt.Dial(ctx, addr, p)
// We're recording any error as a failure here.
// Notably, this also applies to cancelations (i.e. if another dial attempt was faster).
// This is ok since the black hole detector uses a very low threshold (5%).
s.bhd.RecordResult(addr, err == nil)
if err != nil {
if s.metricsTracer != nil {
s.metricsTracer.FailedDialing(addr, err)
}
return nil, err
}
canonicallog.LogPeerStatus(100, connC.RemotePeer(), connC.RemoteMultiaddr(), "connection_status", "established", "dir", "outbound")
if s.metricsTracer != nil {
connWithMetrics := wrapWithMetrics(connC, s.metricsTracer, start, network.DirOutbound)
connWithMetrics.completedHandshake()
connC = connWithMetrics
}
// Trust the transport? Yeah... right.
if connC.RemotePeer() != p {
connC.Close()
err = fmt.Errorf("BUG in transport %T: tried to dial %s, dialed %s", p, connC.RemotePeer(), tpt)
log.Error(err)
return nil, err
}
// success! we got one!
return connC, nil
}
// TODO We should have a `IsFdConsuming() bool` method on the `Transport` interface in go-libp2p/core/transport.
// This function checks if any of the transport protocols in the address requires a file descriptor.
// For now:
// A Non-circuit address which has the TCP/UNIX protocol is deemed FD consuming.
// For a circuit-relay address, we look at the address of the relay server/proxy
// and use the same logic as above to decide.
func isFdConsumingAddr(addr ma.Multiaddr) bool {
first, _ := ma.SplitFunc(addr, func(c ma.Component) bool {
return c.Protocol().Code == ma.P_CIRCUIT
})
// for safety
if first == nil {
return true
}
_, err1 := first.ValueForProtocol(ma.P_TCP)
_, err2 := first.ValueForProtocol(ma.P_UNIX)
return err1 == nil || err2 == nil
}
func isRelayAddr(addr ma.Multiaddr) bool {
_, err := addr.ValueForProtocol(ma.P_CIRCUIT)
return err == nil
}
// filterLowPriorityAddresses removes addresses inplace for which we have a better alternative
// 1. If a /quic-v1 address is present, filter out /quic and /webtransport address on the same 2-tuple:
// QUIC v1 is preferred over the deprecated QUIC draft-29, and given the choice, we prefer using
// raw QUIC over using WebTransport.
// 2. If a /tcp address is present, filter out /ws or /wss addresses on the same 2-tuple:
// We prefer using raw TCP over using WebSocket.
func filterLowPriorityAddresses(addrs []ma.Multiaddr) []ma.Multiaddr {
// make a map of QUIC v1 and TCP AddrPorts.
quicV1Addr := make(map[netip.AddrPort]struct{})
tcpAddr := make(map[netip.AddrPort]struct{})
for _, a := range addrs {
switch {
case isProtocolAddr(a, ma.P_WEBTRANSPORT):
case isProtocolAddr(a, ma.P_QUIC_V1):
ap, err := addrPort(a, ma.P_UDP)
if err != nil {
continue
}
quicV1Addr[ap] = struct{}{}
case isProtocolAddr(a, ma.P_WS) || isProtocolAddr(a, ma.P_WSS):
case isProtocolAddr(a, ma.P_TCP):
ap, err := addrPort(a, ma.P_TCP)
if err != nil {
continue
}
tcpAddr[ap] = struct{}{}
}
}
i := 0
for _, a := range addrs {
switch {
case isProtocolAddr(a, ma.P_WEBTRANSPORT) || isProtocolAddr(a, ma.P_QUIC):
ap, err := addrPort(a, ma.P_UDP)
if err != nil {
break
}
if _, ok := quicV1Addr[ap]; ok {
continue
}
case isProtocolAddr(a, ma.P_WS) || isProtocolAddr(a, ma.P_WSS):
ap, err := addrPort(a, ma.P_TCP)
if err != nil {
break
}
if _, ok := tcpAddr[ap]; ok {
continue
}
}
addrs[i] = a
i++
}
return addrs[:i]
}
// addrPort returns the ip and port for a. p should be either ma.P_TCP or ma.P_UDP.
// a must be an (ip, TCP) or (ip, udp) address.
func addrPort(a ma.Multiaddr, p int) (netip.AddrPort, error) {
ip, err := manet.ToIP(a)
if err != nil {
return netip.AddrPort{}, err
}
port, err := a.ValueForProtocol(p)
if err != nil {
return netip.AddrPort{}, err
}
pi, err := strconv.Atoi(port)
if err != nil {
return netip.AddrPort{}, err
}
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.AddrPort{}, fmt.Errorf("failed to parse IP %s", ip)
}
return netip.AddrPortFrom(addr, uint16(pi)), nil
}

View File

@@ -0,0 +1,168 @@
package swarm
import (
"errors"
"fmt"
"time"
"github.com/libp2p/go-libp2p/core/canonicallog"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
)
// Listen sets up listeners for all of the given addresses.
// It returns as long as we successfully listen on at least *one* address.
func (s *Swarm) Listen(addrs ...ma.Multiaddr) error {
errs := make([]error, len(addrs))
var succeeded int
for i, a := range addrs {
if err := s.AddListenAddr(a); err != nil {
errs[i] = err
} else {
succeeded++
}
}
for i, e := range errs {
if e != nil {
log.Warnw("listening failed", "on", addrs[i], "error", errs[i])
}
}
if succeeded == 0 && len(addrs) > 0 {
return fmt.Errorf("failed to listen on any addresses: %s", errs)
}
return nil
}
// ListenClose stop and delete listeners for all of the given addresses. If an
// any address belongs to one of the addreses a Listener provides, then the
// Listener will close for *all* addresses it provides. For example if you close
// and address with `/quic`, then the QUIC listener will close and also close
// any `/quic-v1` address.
func (s *Swarm) ListenClose(addrs ...ma.Multiaddr) {
listenersToClose := make(map[transport.Listener]struct{}, len(addrs))
s.listeners.Lock()
for l := range s.listeners.m {
if !containsMultiaddr(addrs, l.Multiaddr()) {
continue
}
delete(s.listeners.m, l)
listenersToClose[l] = struct{}{}
}
s.listeners.cacheEOL = time.Time{}
s.listeners.Unlock()
for l := range listenersToClose {
l.Close()
}
}
// AddListenAddr tells the swarm to listen on a single address. Unlike Listen,
// this method does not attempt to filter out bad addresses.
func (s *Swarm) AddListenAddr(a ma.Multiaddr) error {
tpt := s.TransportForListening(a)
if tpt == nil {
// TransportForListening will return nil if either:
// 1. No transport has been registered.
// 2. We're closed (so we've nulled out the transport map.
//
// Distinguish between these two cases to avoid confusing users.
select {
case <-s.ctx.Done():
return ErrSwarmClosed
default:
return ErrNoTransport
}
}
list, err := tpt.Listen(a)
if err != nil {
return err
}
s.listeners.Lock()
if s.listeners.m == nil {
s.listeners.Unlock()
list.Close()
return ErrSwarmClosed
}
s.refs.Add(1)
s.listeners.m[list] = struct{}{}
s.listeners.cacheEOL = time.Time{}
s.listeners.Unlock()
maddr := list.Multiaddr()
// signal to our notifiees on listen.
s.notifyAll(func(n network.Notifiee) {
n.Listen(s, maddr)
})
go func() {
defer func() {
s.listeners.Lock()
_, ok := s.listeners.m[list]
if ok {
delete(s.listeners.m, list)
s.listeners.cacheEOL = time.Time{}
}
s.listeners.Unlock()
if ok {
list.Close()
log.Errorf("swarm listener unintentionally closed")
}
// signal to our notifiees on listen close.
s.notifyAll(func(n network.Notifiee) {
n.ListenClose(s, maddr)
})
s.refs.Done()
}()
for {
c, err := list.Accept()
if err != nil {
if !errors.Is(err, transport.ErrListenerClosed) {
log.Errorf("swarm listener for %s accept error: %s", a, err)
}
return
}
canonicallog.LogPeerStatus(100, c.RemotePeer(), c.RemoteMultiaddr(), "connection_status", "established", "dir", "inbound")
if s.metricsTracer != nil {
c = wrapWithMetrics(c, s.metricsTracer, time.Now(), network.DirInbound)
}
log.Debugf("swarm listener accepted connection: %s <-> %s", c.LocalMultiaddr(), c.RemoteMultiaddr())
s.refs.Add(1)
go func() {
defer s.refs.Done()
_, err := s.addConn(c, network.DirInbound)
switch err {
case nil:
case ErrSwarmClosed:
// ignore.
return
default:
log.Warnw("adding connection failed", "to", a, "error", err)
return
}
}()
}
}()
return nil
}
func containsMultiaddr(addrs []ma.Multiaddr, addr ma.Multiaddr) bool {
for _, a := range addrs {
if addr.Equal(a) {
return true
}
}
return false
}

View File

@@ -0,0 +1,277 @@
package swarm
import (
"context"
"errors"
"net"
"strings"
"time"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/p2p/metricshelper"
ma "github.com/multiformats/go-multiaddr"
"github.com/prometheus/client_golang/prometheus"
)
const metricNamespace = "libp2p_swarm"
var (
connsOpened = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "connections_opened_total",
Help: "Connections Opened",
},
[]string{"dir", "transport", "security", "muxer", "early_muxer", "ip_version"},
)
keyTypes = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "key_types_total",
Help: "key type",
},
[]string{"dir", "key_type"},
)
connsClosed = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "connections_closed_total",
Help: "Connections Closed",
},
[]string{"dir", "transport", "security", "muxer", "early_muxer", "ip_version"},
)
dialError = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "dial_errors_total",
Help: "Dial Error",
},
[]string{"transport", "error", "ip_version"},
)
connDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: metricNamespace,
Name: "connection_duration_seconds",
Help: "Duration of a Connection",
Buckets: prometheus.ExponentialBuckets(1.0/16, 2, 25), // up to 24 days
},
[]string{"dir", "transport", "security", "muxer", "early_muxer", "ip_version"},
)
connHandshakeLatency = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: metricNamespace,
Name: "handshake_latency_seconds",
Help: "Duration of the libp2p Handshake",
Buckets: prometheus.ExponentialBuckets(0.001, 1.3, 35),
},
[]string{"transport", "security", "muxer", "early_muxer", "ip_version"},
)
dialsPerPeer = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Name: "dials_per_peer_total",
Help: "Number of addresses dialed per peer",
},
[]string{"outcome", "num_dials"},
)
dialRankingDelay = prometheus.NewHistogram(
prometheus.HistogramOpts{
Namespace: metricNamespace,
Name: "dial_ranking_delay_seconds",
Help: "delay introduced by the dial ranking logic",
Buckets: []float64{0.001, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.75, 1, 2},
},
)
blackHoleFilterState = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "black_hole_filter_state",
Help: "State of the black hole filter",
},
[]string{"name"},
)
blackHoleFilterSuccessFraction = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "black_hole_filter_success_fraction",
Help: "Fraction of successful dials among the last n requests",
},
[]string{"name"},
)
blackHoleFilterNextRequestAllowedAfter = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricNamespace,
Name: "black_hole_filter_next_request_allowed_after",
Help: "Number of requests after which the next request will be allowed",
},
[]string{"name"},
)
collectors = []prometheus.Collector{
connsOpened,
keyTypes,
connsClosed,
dialError,
connDuration,
connHandshakeLatency,
dialsPerPeer,
dialRankingDelay,
blackHoleFilterSuccessFraction,
blackHoleFilterState,
blackHoleFilterNextRequestAllowedAfter,
}
)
type MetricsTracer interface {
OpenedConnection(network.Direction, crypto.PubKey, network.ConnectionState, ma.Multiaddr)
ClosedConnection(network.Direction, time.Duration, network.ConnectionState, ma.Multiaddr)
CompletedHandshake(time.Duration, network.ConnectionState, ma.Multiaddr)
FailedDialing(ma.Multiaddr, error)
DialCompleted(success bool, totalDials int)
DialRankingDelay(d time.Duration)
UpdatedBlackHoleFilterState(name string, state blackHoleState, nextProbeAfter int, successFraction float64)
}
type metricsTracer struct{}
var _ MetricsTracer = &metricsTracer{}
type metricsTracerSetting struct {
reg prometheus.Registerer
}
type MetricsTracerOption func(*metricsTracerSetting)
func WithRegisterer(reg prometheus.Registerer) MetricsTracerOption {
return func(s *metricsTracerSetting) {
if reg != nil {
s.reg = reg
}
}
}
func NewMetricsTracer(opts ...MetricsTracerOption) MetricsTracer {
setting := &metricsTracerSetting{reg: prometheus.DefaultRegisterer}
for _, opt := range opts {
opt(setting)
}
metricshelper.RegisterCollectors(setting.reg, collectors...)
return &metricsTracer{}
}
func appendConnectionState(tags []string, cs network.ConnectionState) []string {
if cs.Transport == "" {
// This shouldn't happen, unless the transport doesn't properly set the Transport field in the ConnectionState.
tags = append(tags, "unknown")
} else {
tags = append(tags, string(cs.Transport))
}
// These might be empty, depending on the transport.
// For example, QUIC doesn't set security nor muxer.
tags = append(tags, string(cs.Security))
tags = append(tags, string(cs.StreamMultiplexer))
earlyMuxer := "false"
if cs.UsedEarlyMuxerNegotiation {
earlyMuxer = "true"
}
tags = append(tags, earlyMuxer)
return tags
}
func (m *metricsTracer) OpenedConnection(dir network.Direction, p crypto.PubKey, cs network.ConnectionState, laddr ma.Multiaddr) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, metricshelper.GetDirection(dir))
*tags = appendConnectionState(*tags, cs)
*tags = append(*tags, metricshelper.GetIPVersion(laddr))
connsOpened.WithLabelValues(*tags...).Inc()
*tags = (*tags)[:0]
*tags = append(*tags, metricshelper.GetDirection(dir))
*tags = append(*tags, p.Type().String())
keyTypes.WithLabelValues(*tags...).Inc()
}
func (m *metricsTracer) ClosedConnection(dir network.Direction, duration time.Duration, cs network.ConnectionState, laddr ma.Multiaddr) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, metricshelper.GetDirection(dir))
*tags = appendConnectionState(*tags, cs)
*tags = append(*tags, metricshelper.GetIPVersion(laddr))
connsClosed.WithLabelValues(*tags...).Inc()
connDuration.WithLabelValues(*tags...).Observe(duration.Seconds())
}
func (m *metricsTracer) CompletedHandshake(t time.Duration, cs network.ConnectionState, laddr ma.Multiaddr) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = appendConnectionState(*tags, cs)
*tags = append(*tags, metricshelper.GetIPVersion(laddr))
connHandshakeLatency.WithLabelValues(*tags...).Observe(t.Seconds())
}
func (m *metricsTracer) FailedDialing(addr ma.Multiaddr, err error) {
transport := metricshelper.GetTransport(addr)
e := "other"
if errors.Is(err, context.Canceled) {
e = "canceled"
} else if errors.Is(err, context.DeadlineExceeded) {
e = "deadline"
} else {
nerr, ok := err.(net.Error)
if ok && nerr.Timeout() {
e = "timeout"
} else if strings.Contains(err.Error(), "connect: connection refused") {
e = "connection refused"
}
}
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, transport, e)
*tags = append(*tags, metricshelper.GetIPVersion(addr))
dialError.WithLabelValues(*tags...).Inc()
}
func (m *metricsTracer) DialCompleted(success bool, totalDials int) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
if success {
*tags = append(*tags, "success")
} else {
*tags = append(*tags, "failed")
}
numDialLabels := [...]string{"0", "1", "2", "3", "4", "5", ">=6"}
var numDials string
if totalDials < len(numDialLabels) {
numDials = numDialLabels[totalDials]
} else {
numDials = numDialLabels[len(numDialLabels)-1]
}
*tags = append(*tags, numDials)
dialsPerPeer.WithLabelValues(*tags...).Inc()
}
func (m *metricsTracer) DialRankingDelay(d time.Duration) {
dialRankingDelay.Observe(d.Seconds())
}
func (m *metricsTracer) UpdatedBlackHoleFilterState(name string, state blackHoleState,
nextProbeAfter int, successFraction float64) {
tags := metricshelper.GetStringSlice()
defer metricshelper.PutStringSlice(tags)
*tags = append(*tags, name)
blackHoleFilterState.WithLabelValues(*tags...).Set(float64(state))
blackHoleFilterSuccessFraction.WithLabelValues(*tags...).Set(successFraction)
blackHoleFilterNextRequestAllowedAfter.WithLabelValues(*tags...).Set(float64(nextProbeAfter))
}

View File

@@ -0,0 +1,154 @@
package swarm
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/protocol"
)
// Validate Stream conforms to the go-libp2p-net Stream interface
var _ network.Stream = &Stream{}
// Stream is the stream type used by swarm. In general, you won't use this type
// directly.
type Stream struct {
id uint64
stream network.MuxedStream
conn *Conn
scope network.StreamManagementScope
closeOnce sync.Once
protocol atomic.Pointer[protocol.ID]
stat network.Stats
}
func (s *Stream) ID() string {
// format: <first 10 chars of peer id>-<global conn ordinal>-<global stream ordinal>
return fmt.Sprintf("%s-%d", s.conn.ID(), s.id)
}
func (s *Stream) String() string {
return fmt.Sprintf(
"<swarm.Stream[%s] %s (%s) <-> %s (%s)>",
s.conn.conn.Transport(),
s.conn.LocalMultiaddr(),
s.conn.LocalPeer(),
s.conn.RemoteMultiaddr(),
s.conn.RemotePeer(),
)
}
// Conn returns the Conn associated with this stream, as an network.Conn
func (s *Stream) Conn() network.Conn {
return s.conn
}
// Read reads bytes from a stream.
func (s *Stream) Read(p []byte) (int, error) {
n, err := s.stream.Read(p)
// TODO: push this down to a lower level for better accuracy.
if s.conn.swarm.bwc != nil {
s.conn.swarm.bwc.LogRecvMessage(int64(n))
s.conn.swarm.bwc.LogRecvMessageStream(int64(n), s.Protocol(), s.Conn().RemotePeer())
}
return n, err
}
// Write writes bytes to a stream, flushing for each call.
func (s *Stream) Write(p []byte) (int, error) {
n, err := s.stream.Write(p)
// TODO: push this down to a lower level for better accuracy.
if s.conn.swarm.bwc != nil {
s.conn.swarm.bwc.LogSentMessage(int64(n))
s.conn.swarm.bwc.LogSentMessageStream(int64(n), s.Protocol(), s.Conn().RemotePeer())
}
return n, err
}
// Close closes the stream, closing both ends and freeing all associated
// resources.
func (s *Stream) Close() error {
err := s.stream.Close()
s.closeOnce.Do(s.remove)
return err
}
// Reset resets the stream, signaling an error on both ends and freeing all
// associated resources.
func (s *Stream) Reset() error {
err := s.stream.Reset()
s.closeOnce.Do(s.remove)
return err
}
// CloseWrite closes the stream for writing, flushing all data and sending an EOF.
// This function does not free resources, call Close or Reset when done with the
// stream.
func (s *Stream) CloseWrite() error {
return s.stream.CloseWrite()
}
// CloseRead closes the stream for reading. This function does not free resources,
// call Close or Reset when done with the stream.
func (s *Stream) CloseRead() error {
return s.stream.CloseRead()
}
func (s *Stream) remove() {
s.conn.removeStream(s)
s.conn.swarm.refs.Done()
}
// Protocol returns the protocol negotiated on this stream (if set).
func (s *Stream) Protocol() protocol.ID {
p := s.protocol.Load()
if p == nil {
return ""
}
return *p
}
// SetProtocol sets the protocol for this stream.
//
// This doesn't actually *do* anything other than record the fact that we're
// speaking the given protocol over this stream. It's still up to the user to
// negotiate the protocol. This is usually done by the Host.
func (s *Stream) SetProtocol(p protocol.ID) error {
if err := s.scope.SetProtocol(p); err != nil {
return err
}
s.protocol.Store(&p)
return nil
}
// SetDeadline sets the read and write deadlines for this stream.
func (s *Stream) SetDeadline(t time.Time) error {
return s.stream.SetDeadline(t)
}
// SetReadDeadline sets the read deadline for this stream.
func (s *Stream) SetReadDeadline(t time.Time) error {
return s.stream.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline for this stream.
func (s *Stream) SetWriteDeadline(t time.Time) error {
return s.stream.SetWriteDeadline(t)
}
// Stat returns metadata information for this stream.
func (s *Stream) Stat() network.Stats {
return s.stat
}
func (s *Stream) Scope() network.StreamScope {
return s.scope
}

View File

@@ -0,0 +1,109 @@
package swarm
import (
"fmt"
"strings"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
)
// TransportForDialing retrieves the appropriate transport for dialing the given
// multiaddr.
func (s *Swarm) TransportForDialing(a ma.Multiaddr) transport.Transport {
protocols := a.Protocols()
if len(protocols) == 0 {
return nil
}
s.transports.RLock()
defer s.transports.RUnlock()
if len(s.transports.m) == 0 {
// make sure we're not just shutting down.
if s.transports.m != nil {
log.Error("you have no transports configured")
}
return nil
}
if isRelayAddr(a) {
return s.transports.m[ma.P_CIRCUIT]
}
for _, t := range s.transports.m {
if t.CanDial(a) {
return t
}
}
return nil
}
// TransportForListening retrieves the appropriate transport for listening on
// the given multiaddr.
func (s *Swarm) TransportForListening(a ma.Multiaddr) transport.Transport {
protocols := a.Protocols()
if len(protocols) == 0 {
return nil
}
s.transports.RLock()
defer s.transports.RUnlock()
if len(s.transports.m) == 0 {
// make sure we're not just shutting down.
if s.transports.m != nil {
log.Error("you have no transports configured")
}
return nil
}
selected := s.transports.m[protocols[len(protocols)-1].Code]
for _, p := range protocols {
transport, ok := s.transports.m[p.Code]
if !ok {
continue
}
if transport.Proxy() {
selected = transport
}
}
return selected
}
// AddTransport adds a transport to this swarm.
//
// Satisfies the Network interface from go-libp2p-transport.
func (s *Swarm) AddTransport(t transport.Transport) error {
protocols := t.Protocols()
if len(protocols) == 0 {
return fmt.Errorf("useless transport handles no protocols: %T", t)
}
s.transports.Lock()
defer s.transports.Unlock()
if s.transports.m == nil {
return ErrSwarmClosed
}
var registered []string
for _, p := range protocols {
if _, ok := s.transports.m[p]; ok {
proto := ma.ProtocolWithCode(p)
name := proto.Name
if name == "" {
name = fmt.Sprintf("unknown (%d)", p)
}
registered = append(registered, name)
}
}
if len(registered) > 0 {
return fmt.Errorf(
"transports already registered for protocol(s): %s",
strings.Join(registered, ", "),
)
}
for _, p := range protocols {
s.transports.m[p] = t
}
return nil
}

View File

@@ -0,0 +1,65 @@
package upgrader
import (
"fmt"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/transport"
)
type transportConn struct {
network.MuxedConn
network.ConnMultiaddrs
network.ConnSecurity
transport transport.Transport
scope network.ConnManagementScope
stat network.ConnStats
muxer protocol.ID
security protocol.ID
usedEarlyMuxerNegotiation bool
}
var _ transport.CapableConn = &transportConn{}
func (t *transportConn) Transport() transport.Transport {
return t.transport
}
func (t *transportConn) String() string {
ts := ""
if s, ok := t.transport.(fmt.Stringer); ok {
ts = "[" + s.String() + "]"
}
return fmt.Sprintf(
"<stream.Conn%s %s (%s) <-> %s (%s)>",
ts,
t.LocalMultiaddr(),
t.LocalPeer(),
t.RemoteMultiaddr(),
t.RemotePeer(),
)
}
func (t *transportConn) Stat() network.ConnStats {
return t.stat
}
func (t *transportConn) Scope() network.ConnScope {
return t.scope
}
func (t *transportConn) Close() error {
defer t.scope.Done()
return t.MuxedConn.Close()
}
func (t *transportConn) ConnState() network.ConnectionState {
return network.ConnectionState{
StreamMultiplexer: t.muxer,
Security: t.security,
Transport: "tcp",
UsedEarlyMuxerNegotiation: t.usedEarlyMuxerNegotiation,
}
}

View File

@@ -0,0 +1,182 @@
package upgrader
import (
"context"
"fmt"
"strings"
"sync"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
logging "github.com/ipfs/go-log/v2"
tec "github.com/jbenet/go-temp-err-catcher"
manet "github.com/multiformats/go-multiaddr/net"
)
var log = logging.Logger("upgrader")
type listener struct {
manet.Listener
transport transport.Transport
upgrader *upgrader
rcmgr network.ResourceManager
incoming chan transport.CapableConn
err error
// Used for backpressure
threshold *threshold
// Canceling this context isn't sufficient to tear down the listener.
// Call close.
ctx context.Context
cancel func()
}
// Close closes the listener.
func (l *listener) Close() error {
// Do this first to try to get any relevent errors.
err := l.Listener.Close()
l.cancel()
// Drain and wait.
for c := range l.incoming {
c.Close()
}
return err
}
// handles inbound connections.
//
// This function does a few interesting things that should be noted:
//
// 1. It logs and discards temporary/transient errors (errors with a Temporary()
// function that returns true).
// 2. It stops accepting new connections once AcceptQueueLength connections have
// been fully negotiated but not accepted. This gives us a basic backpressure
// mechanism while still allowing us to negotiate connections in parallel.
func (l *listener) handleIncoming() {
var wg sync.WaitGroup
defer func() {
// make sure we're closed
l.Listener.Close()
if l.err == nil {
l.err = fmt.Errorf("listener closed")
}
wg.Wait()
close(l.incoming)
}()
var catcher tec.TempErrCatcher
for l.ctx.Err() == nil {
maconn, err := l.Listener.Accept()
if err != nil {
// Note: function may pause the accept loop.
if catcher.IsTemporary(err) {
log.Infof("temporary accept error: %s", err)
continue
}
l.err = err
return
}
catcher.Reset()
// gate the connection if applicable
if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
maconn.LocalMultiaddr(), maconn.RemoteMultiaddr())
if err := maconn.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
continue
}
connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := maconn.Close(); err != nil {
log.Warnf("failed to incoming connection rejected by resource manager: %s", err)
}
continue
}
// The go routine below calls Release when the context is
// canceled so there's no need to wait on it here.
l.threshold.Wait()
log.Debugf("listener %s got connection: %s <---> %s",
l,
maconn.LocalMultiaddr(),
maconn.RemoteMultiaddr())
wg.Add(1)
go func() {
defer wg.Done()
ctx, cancel := context.WithTimeout(l.ctx, l.upgrader.acceptTimeout)
defer cancel()
conn, err := l.upgrader.Upgrade(ctx, l.transport, maconn, network.DirInbound, "", connScope)
if err != nil {
// Don't bother bubbling this up. We just failed
// to completely negotiate the connection.
log.Debugf("accept upgrade error: %s (%s <--> %s)",
err,
maconn.LocalMultiaddr(),
maconn.RemoteMultiaddr())
connScope.Done()
return
}
log.Debugf("listener %s accepted connection: %s", l, conn)
// This records the fact that the connection has been
// setup and is waiting to be accepted. This call
// *never* blocks, even if we go over the threshold. It
// simply ensures that calls to Wait block while we're
// over the threshold.
l.threshold.Acquire()
defer l.threshold.Release()
select {
case l.incoming <- conn:
case <-ctx.Done():
if l.ctx.Err() == nil {
// Listener *not* closed but the accept timeout expired.
log.Warn("listener dropped connection due to slow accept")
}
// Wait on the context with a timeout. This way,
// if we stop accepting connections for some reason,
// we'll eventually close all the open ones
// instead of hanging onto them.
conn.Close()
}
}()
}
}
// Accept accepts a connection.
func (l *listener) Accept() (transport.CapableConn, error) {
for c := range l.incoming {
// Could have been sitting there for a while.
if !c.IsClosed() {
return c, nil
}
}
if strings.Contains(l.err.Error(), "use of closed network connection") {
return nil, transport.ErrListenerClosed
}
return nil, l.err
}
func (l *listener) String() string {
if s, ok := l.transport.(fmt.Stringer); ok {
return fmt.Sprintf("<stream.Listener[%s] %s>", s, l.Multiaddr())
}
return fmt.Sprintf("<stream.Listener %s>", l.Multiaddr())
}
var _ transport.Listener = (*listener)(nil)

View File

@@ -0,0 +1,50 @@
package upgrader
import (
"sync"
)
func newThreshold(cutoff int) *threshold {
t := &threshold{
threshold: cutoff,
}
t.cond.L = &t.mu
return t
}
type threshold struct {
mu sync.Mutex
cond sync.Cond
count int
threshold int
}
// Acquire increments the counter. It will not block.
func (t *threshold) Acquire() {
t.mu.Lock()
t.count++
t.mu.Unlock()
}
// Release decrements the counter.
func (t *threshold) Release() {
t.mu.Lock()
if t.count == 0 {
panic("negative count")
}
if t.threshold == t.count {
t.cond.Broadcast()
}
t.count--
t.mu.Unlock()
}
// Wait waits for the counter to drop below the threshold
func (t *threshold) Wait() {
t.mu.Lock()
for t.count >= t.threshold {
t.cond.Wait()
}
t.mu.Unlock()
}

Some files were not shown because too many files have changed in this diff Show More