Update dependencies (#2180)

* Update dependencies

* Fix whatsmeow API changes
This commit is contained in:
Wim
2024-08-27 19:04:05 +02:00
committed by GitHub
parent d16645c952
commit c4157a4d5b
589 changed files with 681707 additions and 198856 deletions

View File

@@ -1,4 +1,4 @@
// Copyright 2022 Google LLC
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.21.9
// protoc v4.24.4
// source: google/rpc/status.proto
package status

View File

@@ -66,7 +66,7 @@ How to get your contributions merged smoothly and quickly.
- **All tests need to be passing** before your change can be merged. We
recommend you **run tests locally** before creating your PR to catch breakages
early on.
- `VET_SKIP_PROTO=1 ./vet.sh` to catch vet errors
- `./scripts/vet.sh` to catch vet errors
- `go test -cpu 1,4 -timeout 7m ./...` to run the tests
- `go test -race -cpu 1,4 -timeout 7m ./...` to run tests in race mode

View File

@@ -9,6 +9,7 @@ for general contribution guidelines.
## Maintainers (in alphabetical order)
- [atollena](https://github.com/atollena), Datadog, Inc.
- [cesarghali](https://github.com/cesarghali), Google LLC
- [dfawley](https://github.com/dfawley), Google LLC
- [easwars](https://github.com/easwars), Google LLC

View File

@@ -30,17 +30,20 @@ testdeps:
GO111MODULE=on go get -d -v -t google.golang.org/grpc/...
vet: vetdeps
./vet.sh
./scripts/vet.sh
vetdeps:
./vet.sh -install
./scripts/vet.sh -install
.PHONY: \
all \
build \
clean \
deps \
proto \
test \
testsubmodule \
testrace \
testdeps \
vet \
vetdeps

View File

@@ -10,7 +10,7 @@ RPC framework that puts mobile and HTTP/2 first. For more information see the
## Prerequisites
- **[Go][]**: any one of the **three latest major** [releases][go-releases].
- **[Go][]**: any one of the **two latest major** [releases][go-releases].
## Installation

View File

@@ -54,13 +54,14 @@ var (
// an init() function), and is not thread-safe. If multiple Balancers are
// registered with the same name, the one registered last will take effect.
func Register(b Builder) {
if strings.ToLower(b.Name()) != b.Name() {
name := strings.ToLower(b.Name())
if name != b.Name() {
// TODO: Skip the use of strings.ToLower() to index the map after v1.59
// is released to switch to case sensitive balancer registry. Also,
// remove this warning and update the docstrings for Register and Get.
logger.Warningf("Balancer registered with name %q. grpc-go will be switching to case sensitive balancer registries soon", b.Name())
}
m[strings.ToLower(b.Name())] = b
m[name] = b
}
// unregisterForTesting deletes the balancer with the given name from the
@@ -232,8 +233,8 @@ type BuildOptions struct {
// implementations which do not communicate with a remote load balancer
// server can ignore this field.
Authority string
// ChannelzParentID is the parent ClientConn's channelz ID.
ChannelzParentID *channelz.Identifier
// ChannelzParent is the parent ClientConn's channelz channel.
ChannelzParent channelz.Identifier
// CustomUserAgent is the custom user agent set on the parent ClientConn.
// The balancer should set the same custom user agent if it creates a
// ClientConn.

View File

@@ -16,54 +16,60 @@
*
*/
package grpc
// Package pickfirst contains the pick_first load balancing policy.
package pickfirst
import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
const (
// PickFirstBalancerName is the name of the pick_first balancer.
PickFirstBalancerName = "pick_first"
logPrefix = "[pick-first-lb %p] "
)
func newPickfirstBuilder() balancer.Builder {
return &pickfirstBuilder{}
func init() {
balancer.Register(pickfirstBuilder{})
internal.ShuffleAddressListForTesting = func(n int, swap func(i, j int)) { rand.Shuffle(n, swap) }
}
var logger = grpclog.Component("pick-first-lb")
const (
// Name is the name of the pick_first balancer.
Name = "pick_first"
logPrefix = "[pick-first-lb %p] "
)
type pickfirstBuilder struct{}
func (*pickfirstBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
func (pickfirstBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
b := &pickfirstBalancer{cc: cc}
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b))
return b
}
func (*pickfirstBuilder) Name() string {
return PickFirstBalancerName
func (pickfirstBuilder) Name() string {
return Name
}
type pfConfig struct {
serviceconfig.LoadBalancingConfig `json:"-"`
// If set to true, instructs the LB policy to shuffle the order of the list
// of addresses received from the name resolver before attempting to
// of endpoints received from the name resolver before attempting to
// connect to them.
ShuffleAddressList bool `json:"shuffleAddressList"`
}
func (*pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
func (pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
var cfg pfConfig
if err := json.Unmarshal(js, &cfg); err != nil {
return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err)
@@ -97,9 +103,14 @@ func (b *pickfirstBalancer) ResolverError(err error) {
})
}
type Shuffler interface {
ShuffleAddressListForTesting(n int, swap func(i, j int))
}
func ShuffleAddressListForTesting(n int, swap func(i, j int)) { rand.Shuffle(n, swap) }
func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
addrs := state.ResolverState.Addresses
if len(addrs) == 0 {
if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 {
// The resolver reported an empty address list. Treat it like an error by
// calling b.ResolverError.
if b.subConn != nil {
@@ -111,22 +122,49 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState
b.ResolverError(errors.New("produced zero addresses"))
return balancer.ErrBadResolverState
}
// We don't have to guard this block with the env var because ParseConfig
// already does so.
cfg, ok := state.BalancerConfig.(pfConfig)
if state.BalancerConfig != nil && !ok {
return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v", state.BalancerConfig, state.BalancerConfig)
}
if cfg.ShuffleAddressList {
addrs = append([]resolver.Address{}, addrs...)
grpcrand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] })
}
if b.logger.V(2) {
b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState))
}
var addrs []resolver.Address
if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 {
// Perform the optional shuffling described in gRFC A62. The shuffling will
// change the order of endpoints but not touch the order of the addresses
// within each endpoint. - A61
if cfg.ShuffleAddressList {
endpoints = append([]resolver.Endpoint{}, endpoints...)
internal.ShuffleAddressListForTesting.(func(int, func(int, int)))(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
}
// "Flatten the list by concatenating the ordered list of addresses for each
// of the endpoints, in order." - A61
for _, endpoint := range endpoints {
// "In the flattened list, interleave addresses from the two address
// families, as per RFC-8304 section 4." - A61
// TODO: support the above language.
addrs = append(addrs, endpoint.Addresses...)
}
} else {
// Endpoints not set, process addresses until we migrate resolver
// emissions fully to Endpoints. The top channel does wrap emitted
// addresses with endpoints, however some balancers such as weighted
// target do not forwarrd the corresponding correct endpoints down/split
// endpoints properly. Once all balancers correctly forward endpoints
// down, can delete this else conditional.
addrs = state.ResolverState.Addresses
if cfg.ShuffleAddressList {
addrs = append([]resolver.Address{}, addrs...)
rand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] })
}
}
if b.subConn != nil {
b.cc.UpdateAddresses(b.subConn, addrs)
return nil
@@ -243,7 +281,3 @@ func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
i.subConn.Connect()
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}
func init() {
balancer.Register(newPickfirstBuilder())
}

View File

@@ -22,12 +22,12 @@
package roundrobin
import (
"math/rand"
"sync/atomic"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/grpcrand"
)
// Name is the name of round_robin balancer.
@@ -60,7 +60,7 @@ func (*rrPickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker {
// Start at a random index, as the same RR balancer rebuilds a new
// picker when SubConn states change, and we don't want to apply excess
// load to the first server in the list.
next: uint32(grpcrand.Intn(len(scs))),
next: uint32(rand.Intn(len(scs))),
}
}

View File

@@ -21,7 +21,6 @@ package grpc
import (
"context"
"fmt"
"strings"
"sync"
"google.golang.org/grpc/balancer"
@@ -66,19 +65,20 @@ type ccBalancerWrapper struct {
}
// newCCBalancerWrapper creates a new balancer wrapper in idle state. The
// underlying balancer is not created until the switchTo() method is invoked.
// underlying balancer is not created until the updateClientConnState() method
// is invoked.
func newCCBalancerWrapper(cc *ClientConn) *ccBalancerWrapper {
ctx, cancel := context.WithCancel(cc.ctx)
ccb := &ccBalancerWrapper{
cc: cc,
opts: balancer.BuildOptions{
DialCreds: cc.dopts.copts.TransportCredentials,
CredsBundle: cc.dopts.copts.CredsBundle,
Dialer: cc.dopts.copts.Dialer,
Authority: cc.authority,
CustomUserAgent: cc.dopts.copts.UserAgent,
ChannelzParentID: cc.channelzID,
Target: cc.parsedTarget,
DialCreds: cc.dopts.copts.TransportCredentials,
CredsBundle: cc.dopts.copts.CredsBundle,
Dialer: cc.dopts.copts.Dialer,
Authority: cc.authority,
CustomUserAgent: cc.dopts.copts.UserAgent,
ChannelzParent: cc.channelz,
Target: cc.parsedTarget,
},
serializer: grpcsync.NewCallbackSerializer(ctx),
serializerCancel: cancel,
@@ -97,6 +97,11 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat
if ctx.Err() != nil || ccb.balancer == nil {
return
}
name := gracefulswitch.ChildName(ccs.BalancerConfig)
if ccb.curBalancerName != name {
ccb.curBalancerName = name
channelz.Infof(logger, ccb.cc.channelz, "Channel switches to new LB policy %q", name)
}
err := ccb.balancer.UpdateClientConnState(*ccs)
if logger.V(2) && err != nil {
logger.Infof("error from balancer.UpdateClientConnState: %v", err)
@@ -120,54 +125,6 @@ func (ccb *ccBalancerWrapper) resolverError(err error) {
})
}
// switchTo is invoked by grpc to instruct the balancer wrapper to switch to the
// LB policy identified by name.
//
// ClientConn calls newCCBalancerWrapper() at creation time. Upon receipt of the
// first good update from the name resolver, it determines the LB policy to use
// and invokes the switchTo() method. Upon receipt of every subsequent update
// from the name resolver, it invokes this method.
//
// the ccBalancerWrapper keeps track of the current LB policy name, and skips
// the graceful balancer switching process if the name does not change.
func (ccb *ccBalancerWrapper) switchTo(name string) {
ccb.serializer.Schedule(func(ctx context.Context) {
if ctx.Err() != nil || ccb.balancer == nil {
return
}
// TODO: Other languages use case-sensitive balancer registries. We should
// switch as well. See: https://github.com/grpc/grpc-go/issues/5288.
if strings.EqualFold(ccb.curBalancerName, name) {
return
}
ccb.buildLoadBalancingPolicy(name)
})
}
// buildLoadBalancingPolicy performs the following:
// - retrieve a balancer builder for the given name. Use the default LB
// policy, pick_first, if no LB policy with name is found in the registry.
// - instruct the gracefulswitch balancer to switch to the above builder. This
// will actually build the new balancer.
// - update the `curBalancerName` field
//
// Must be called from a serializer callback.
func (ccb *ccBalancerWrapper) buildLoadBalancingPolicy(name string) {
builder := balancer.Get(name)
if builder == nil {
channelz.Warningf(logger, ccb.cc.channelzID, "Channel switches to new LB policy %q, since the specified LB policy %q was not registered", PickFirstBalancerName, name)
builder = newPickfirstBuilder()
} else {
channelz.Infof(logger, ccb.cc.channelzID, "Channel switches to new LB policy %q", name)
}
if err := ccb.balancer.SwitchTo(builder); err != nil {
channelz.Errorf(logger, ccb.cc.channelzID, "Channel failed to build new LB policy %q: %v", name, err)
return
}
ccb.curBalancerName = builder.Name()
}
// close initiates async shutdown of the wrapper. cc.mu must be held when
// calling this function. To determine the wrapper has finished shutting down,
// the channel should block on ccb.serializer.Done() without cc.mu held.
@@ -175,7 +132,7 @@ func (ccb *ccBalancerWrapper) close() {
ccb.mu.Lock()
ccb.closed = true
ccb.mu.Unlock()
channelz.Info(logger, ccb.cc.channelzID, "ccBalancerWrapper: closing")
channelz.Info(logger, ccb.cc.channelz, "ccBalancerWrapper: closing")
ccb.serializer.Schedule(func(context.Context) {
if ccb.balancer == nil {
return
@@ -212,7 +169,7 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer
}
ac, err := ccb.cc.newAddrConnLocked(addrs, opts)
if err != nil {
channelz.Warningf(logger, ccb.cc.channelzID, "acBalancerWrapper: NewSubConn: failed to newAddrConn: %v", err)
channelz.Warningf(logger, ccb.cc.channelz, "acBalancerWrapper: NewSubConn: failed to newAddrConn: %v", err)
return nil, err
}
acbw := &acBalancerWrapper{
@@ -241,6 +198,10 @@ func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resol
func (ccb *ccBalancerWrapper) UpdateState(s balancer.State) {
ccb.cc.mu.Lock()
defer ccb.cc.mu.Unlock()
if ccb.cc.conns == nil {
// The CC has been closed; ignore this update.
return
}
ccb.mu.Lock()
if ccb.closed {
@@ -304,7 +265,7 @@ func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) {
}
func (acbw *acBalancerWrapper) String() string {
return fmt.Sprintf("SubConn(id:%d)", acbw.ac.channelzID.Int())
return fmt.Sprintf("SubConn(id:%d)", acbw.ac.channelz.ID)
}
func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {

View File

@@ -18,7 +18,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.32.0
// protoc-gen-go v1.34.1
// protoc v4.25.2
// source: grpc/binlog/v1/binarylog.proto

View File

@@ -31,13 +31,13 @@ import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/balancer/pickfirst"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/idle"
"google.golang.org/grpc/internal/pretty"
iresolver "google.golang.org/grpc/internal/resolver"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
@@ -67,12 +67,14 @@ var (
errConnDrain = errors.New("grpc: the connection is drained")
// errConnClosing indicates that the connection is closing.
errConnClosing = errors.New("grpc: the connection is closing")
// errConnIdling indicates the the connection is being closed as the channel
// errConnIdling indicates the connection is being closed as the channel
// is moving to an idle mode due to inactivity.
errConnIdling = errors.New("grpc: the connection is closing due to channel idleness")
// invalidDefaultServiceConfigErrPrefix is used to prefix the json parsing error for the default
// service config.
invalidDefaultServiceConfigErrPrefix = "grpc: the provided default service config is invalid"
// PickFirstBalancerName is the name of the pick_first balancer.
PickFirstBalancerName = pickfirst.Name
)
// The following errors are returned from Dial and DialContext
@@ -101,11 +103,6 @@ const (
defaultReadBufSize = 32 * 1024
)
// Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...)
}
type defaultConfigSelector struct {
sc *ServiceConfig
}
@@ -117,13 +114,23 @@ func (dcs *defaultConfigSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*ires
}, nil
}
// newClient returns a new client in idle mode.
func newClient(target string, opts ...DialOption) (conn *ClientConn, err error) {
// NewClient creates a new gRPC "channel" for the target URI provided. No I/O
// is performed. Use of the ClientConn for RPCs will automatically cause it to
// connect. Connect may be used to manually create a connection, but for most
// users this is unnecessary.
//
// The target name syntax is defined in
// https://github.com/grpc/grpc/blob/master/doc/naming.md. e.g. to use dns
// resolver, a "dns:///" prefix should be applied to the target.
//
// The DialOptions returned by WithBlock, WithTimeout,
// WithReturnConnectionError, and FailOnNonTempDialError are ignored by this
// function.
func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{
target: target,
conns: make(map[*addrConn]struct{}),
dopts: defaultDialOptions(),
czData: new(channelzData),
}
cc.retryThrottler.Store((*retryThrottler)(nil))
@@ -148,6 +155,16 @@ func newClient(target string, opts ...DialOption) (conn *ClientConn, err error)
for _, opt := range opts {
opt.apply(&cc.dopts)
}
// Determine the resolver to use.
if err := cc.initParsedTargetAndResolverBuilder(); err != nil {
return nil, err
}
for _, opt := range globalPerTargetDialOptions {
opt.DialOptionForTarget(cc.parsedTarget.URL).apply(&cc.dopts)
}
chainUnaryClientInterceptors(cc)
chainStreamClientInterceptors(cc)
@@ -156,7 +173,7 @@ func newClient(target string, opts ...DialOption) (conn *ClientConn, err error)
}
if cc.dopts.defaultServiceConfigRawJSON != nil {
scpr := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON)
scpr := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON, cc.dopts.maxCallAttempts)
if scpr.Err != nil {
return nil, fmt.Errorf("%s: %v", invalidDefaultServiceConfigErrPrefix, scpr.Err)
}
@@ -164,26 +181,17 @@ func newClient(target string, opts ...DialOption) (conn *ClientConn, err error)
}
cc.mkp = cc.dopts.copts.KeepaliveParams
// Register ClientConn with channelz.
if err = cc.initAuthority(); err != nil {
return nil, err
}
// Register ClientConn with channelz. Note that this is only done after
// channel creation cannot fail.
cc.channelzRegistration(target)
channelz.Infof(logger, cc.channelz, "parsed dial target is: %#v", cc.parsedTarget)
channelz.Infof(logger, cc.channelz, "Channel authority set to %q", cc.authority)
// TODO: Ideally it should be impossible to error from this function after
// channelz registration. This will require removing some channelz logs
// from the following functions that can error. Errors can be returned to
// the user, and successful logs can be emitted here, after the checks have
// passed and channelz is subsequently registered.
// Determine the resolver to use.
if err := cc.parseTargetAndFindResolver(); err != nil {
channelz.RemoveEntry(cc.channelzID)
return nil, err
}
if err = cc.determineAuthority(); err != nil {
channelz.RemoveEntry(cc.channelzID)
return nil, err
}
cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelzID)
cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelz)
cc.pickerWrapper = newPickerWrapper(cc.dopts.copts.StatsHandlers)
cc.initIdleStateLocked() // Safe to call without the lock, since nothing else has a reference to cc.
@@ -191,39 +199,36 @@ func newClient(target string, opts ...DialOption) (conn *ClientConn, err error)
return cc, nil
}
// DialContext creates a client connection to the given target. By default, it's
// a non-blocking dial (the function won't wait for connections to be
// established, and connecting happens in the background). To make it a blocking
// dial, use WithBlock() dial option.
// Dial calls DialContext(context.Background(), target, opts...).
//
// In the non-blocking case, the ctx does not act against the connection. It
// only controls the setup steps.
// Deprecated: use NewClient instead. Will be supported throughout 1.x.
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...)
}
// DialContext calls NewClient and then exits idle mode. If WithBlock(true) is
// used, it calls Connect and WaitForStateChange until either the context
// expires or the state of the ClientConn is Ready.
//
// In the blocking case, ctx can be used to cancel or expire the pending
// connection. Once this function returns, the cancellation and expiration of
// ctx will be noop. Users should call ClientConn.Close to terminate all the
// pending operations after this function returns.
// One subtle difference between NewClient and Dial and DialContext is that the
// former uses "dns" as the default name resolver, while the latter use
// "passthrough" for backward compatibility. This distinction should not matter
// to most users, but could matter to legacy users that specify a custom dialer
// and expect it to receive the target string directly.
//
// The target name syntax is defined in
// https://github.com/grpc/grpc/blob/master/doc/naming.md.
// e.g. to use dns resolver, a "dns:///" prefix should be applied to the target.
// Deprecated: use NewClient instead. Will be supported throughout 1.x.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc, err := newClient(target, opts...)
// At the end of this method, we kick the channel out of idle, rather than
// waiting for the first rpc.
opts = append([]DialOption{withDefaultScheme("passthrough")}, opts...)
cc, err := NewClient(target, opts...)
if err != nil {
return nil, err
}
// We start the channel off in idle mode, but kick it out of idle now,
// instead of waiting for the first RPC. Other gRPC implementations do wait
// for the first RPC to kick the channel out of idle. But doing so would be
// a major behavior change for our users who are used to seeing the channel
// active after Dial.
//
// Taking this approach of kicking it out of idle at the end of this method
// allows us to share the code between channel creation and exiting idle
// mode. This will also make it easy for us to switch to starting the
// channel off in idle, i.e. by making newClient exported.
// instead of waiting for the first RPC. This is the legacy behavior of
// Dial.
defer func() {
if err != nil {
cc.Close()
@@ -291,17 +296,17 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
// addTraceEvent is a helper method to add a trace event on the channel. If the
// channel is a nested one, the same event is also added on the parent channel.
func (cc *ClientConn) addTraceEvent(msg string) {
ted := &channelz.TraceEventDesc{
ted := &channelz.TraceEvent{
Desc: fmt.Sprintf("Channel %s", msg),
Severity: channelz.CtInfo,
}
if cc.dopts.channelzParentID != nil {
ted.Parent = &channelz.TraceEventDesc{
Desc: fmt.Sprintf("Nested channel(id:%d) %s", cc.channelzID.Int(), msg),
if cc.dopts.channelzParent != nil {
ted.Parent = &channelz.TraceEvent{
Desc: fmt.Sprintf("Nested channel(id:%d) %s", cc.channelz.ID, msg),
Severity: channelz.CtInfo,
}
}
channelz.AddTraceEvent(logger, cc.channelzID, 0, ted)
channelz.AddTraceEvent(logger, cc.channelz, 0, ted)
}
type idler ClientConn
@@ -418,14 +423,15 @@ func (cc *ClientConn) validateTransportCredentials() error {
}
// channelzRegistration registers the newly created ClientConn with channelz and
// stores the returned identifier in `cc.channelzID` and `cc.csMgr.channelzID`.
// A channelz trace event is emitted for ClientConn creation. If the newly
// created ClientConn is a nested one, i.e a valid parent ClientConn ID is
// specified via a dial option, the trace event is also added to the parent.
// stores the returned identifier in `cc.channelz`. A channelz trace event is
// emitted for ClientConn creation. If the newly created ClientConn is a nested
// one, i.e a valid parent ClientConn ID is specified via a dial option, the
// trace event is also added to the parent.
//
// Doesn't grab cc.mu as this method is expected to be called only at Dial time.
func (cc *ClientConn) channelzRegistration(target string) {
cc.channelzID = channelz.RegisterChannel(&channelzChannel{cc}, cc.dopts.channelzParentID, target)
parentChannel, _ := cc.dopts.channelzParent.(*channelz.Channel)
cc.channelz = channelz.RegisterChannel(parentChannel, target)
cc.addTraceEvent("created")
}
@@ -492,11 +498,11 @@ func getChainStreamer(interceptors []StreamClientInterceptor, curr int, finalStr
}
// newConnectivityStateManager creates an connectivityStateManager with
// the specified id.
func newConnectivityStateManager(ctx context.Context, id *channelz.Identifier) *connectivityStateManager {
// the specified channel.
func newConnectivityStateManager(ctx context.Context, channel *channelz.Channel) *connectivityStateManager {
return &connectivityStateManager{
channelzID: id,
pubSub: grpcsync.NewPubSub(ctx),
channelz: channel,
pubSub: grpcsync.NewPubSub(ctx),
}
}
@@ -510,7 +516,7 @@ type connectivityStateManager struct {
mu sync.Mutex
state connectivity.State
notifyChan chan struct{}
channelzID *channelz.Identifier
channelz *channelz.Channel
pubSub *grpcsync.PubSub
}
@@ -527,9 +533,10 @@ func (csm *connectivityStateManager) updateState(state connectivity.State) {
return
}
csm.state = state
csm.channelz.ChannelMetrics.State.Store(&state)
csm.pubSub.Publish(state)
channelz.Infof(logger, csm.channelzID, "Channel Connectivity change to %v", state)
channelz.Infof(logger, csm.channelz, "Channel Connectivity change to %v", state)
if csm.notifyChan != nil {
// There are other goroutines waiting on this channel.
close(csm.notifyChan)
@@ -583,12 +590,12 @@ type ClientConn struct {
cancel context.CancelFunc // Cancelled on close.
// The following are initialized at dial time, and are read-only after that.
target string // User's dial target.
parsedTarget resolver.Target // See parseTargetAndFindResolver().
authority string // See determineAuthority().
dopts dialOptions // Default and user specified dial options.
channelzID *channelz.Identifier // Channelz identifier for the channel.
resolverBuilder resolver.Builder // See parseTargetAndFindResolver().
target string // User's dial target.
parsedTarget resolver.Target // See initParsedTargetAndResolverBuilder().
authority string // See initAuthority().
dopts dialOptions // Default and user specified dial options.
channelz *channelz.Channel // Channelz object.
resolverBuilder resolver.Builder // See initParsedTargetAndResolverBuilder().
idlenessMgr *idle.Manager
// The following provide their own synchronization, and therefore don't
@@ -596,7 +603,6 @@ type ClientConn struct {
csMgr *connectivityStateManager
pickerWrapper *pickerWrapper
safeConfigSelector iresolver.SafeConfigSelector
czData *channelzData
retryThrottler atomic.Value // Updated from service config.
// mu protects the following fields.
@@ -690,7 +696,7 @@ func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error {
var emptyServiceConfig *ServiceConfig
func init() {
cfg := parseServiceConfig("{}")
cfg := parseServiceConfig("{}", defaultMaxCallAttempts)
if cfg.Err != nil {
panic(fmt.Sprintf("impossible error parsing empty service config: %v", cfg.Err))
}
@@ -707,15 +713,15 @@ func init() {
}
}
func (cc *ClientConn) maybeApplyDefaultServiceConfig(addrs []resolver.Address) {
func (cc *ClientConn) maybeApplyDefaultServiceConfig() {
if cc.sc != nil {
cc.applyServiceConfigAndBalancer(cc.sc, nil, addrs)
cc.applyServiceConfigAndBalancer(cc.sc, nil)
return
}
if cc.dopts.defaultServiceConfig != nil {
cc.applyServiceConfigAndBalancer(cc.dopts.defaultServiceConfig, &defaultConfigSelector{cc.dopts.defaultServiceConfig}, addrs)
cc.applyServiceConfigAndBalancer(cc.dopts.defaultServiceConfig, &defaultConfigSelector{cc.dopts.defaultServiceConfig})
} else {
cc.applyServiceConfigAndBalancer(emptyServiceConfig, &defaultConfigSelector{emptyServiceConfig}, addrs)
cc.applyServiceConfigAndBalancer(emptyServiceConfig, &defaultConfigSelector{emptyServiceConfig})
}
}
@@ -733,7 +739,7 @@ func (cc *ClientConn) updateResolverStateAndUnlock(s resolver.State, err error)
// May need to apply the initial service config in case the resolver
// doesn't support service configs, or doesn't provide a service config
// with the new addresses.
cc.maybeApplyDefaultServiceConfig(nil)
cc.maybeApplyDefaultServiceConfig()
cc.balancerWrapper.resolverError(err)
@@ -744,10 +750,10 @@ func (cc *ClientConn) updateResolverStateAndUnlock(s resolver.State, err error)
var ret error
if cc.dopts.disableServiceConfig {
channelz.Infof(logger, cc.channelzID, "ignoring service config from resolver (%v) and applying the default because service config is disabled", s.ServiceConfig)
cc.maybeApplyDefaultServiceConfig(s.Addresses)
channelz.Infof(logger, cc.channelz, "ignoring service config from resolver (%v) and applying the default because service config is disabled", s.ServiceConfig)
cc.maybeApplyDefaultServiceConfig()
} else if s.ServiceConfig == nil {
cc.maybeApplyDefaultServiceConfig(s.Addresses)
cc.maybeApplyDefaultServiceConfig()
// TODO: do we need to apply a failing LB policy if there is no
// default, per the error handling design?
} else {
@@ -755,12 +761,12 @@ func (cc *ClientConn) updateResolverStateAndUnlock(s resolver.State, err error)
configSelector := iresolver.GetConfigSelector(s)
if configSelector != nil {
if len(s.ServiceConfig.Config.(*ServiceConfig).Methods) != 0 {
channelz.Infof(logger, cc.channelzID, "method configs in service config will be ignored due to presence of config selector")
channelz.Infof(logger, cc.channelz, "method configs in service config will be ignored due to presence of config selector")
}
} else {
configSelector = &defaultConfigSelector{sc}
}
cc.applyServiceConfigAndBalancer(sc, configSelector, s.Addresses)
cc.applyServiceConfigAndBalancer(sc, configSelector)
} else {
ret = balancer.ErrBadResolverState
if cc.sc == nil {
@@ -775,7 +781,7 @@ func (cc *ClientConn) updateResolverStateAndUnlock(s resolver.State, err error)
var balCfg serviceconfig.LoadBalancingConfig
if cc.sc != nil && cc.sc.lbConfig != nil {
balCfg = cc.sc.lbConfig.cfg
balCfg = cc.sc.lbConfig
}
bw := cc.balancerWrapper
cc.mu.Unlock()
@@ -834,22 +840,20 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer.
addrs: copyAddressesWithoutBalancerAttributes(addrs),
scopts: opts,
dopts: cc.dopts,
czData: new(channelzData),
channelz: channelz.RegisterSubChannel(cc.channelz, ""),
resetBackoff: make(chan struct{}),
stateChan: make(chan struct{}),
}
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
// Start with our address set to the first address; this may be updated if
// we connect to different addresses.
ac.channelz.ChannelMetrics.Target.Store(&addrs[0].Addr)
var err error
ac.channelzID, err = channelz.RegisterSubChannel(ac, cc.channelzID, "")
if err != nil {
return nil, err
}
channelz.AddTraceEvent(logger, ac.channelzID, 0, &channelz.TraceEventDesc{
channelz.AddTraceEvent(logger, ac.channelz, 0, &channelz.TraceEvent{
Desc: "Subchannel created",
Severity: channelz.CtInfo,
Parent: &channelz.TraceEventDesc{
Desc: fmt.Sprintf("Subchannel(id:%d) created", ac.channelzID.Int()),
Parent: &channelz.TraceEvent{
Desc: fmt.Sprintf("Subchannel(id:%d) created", ac.channelz.ID),
Severity: channelz.CtInfo,
},
})
@@ -872,38 +876,27 @@ func (cc *ClientConn) removeAddrConn(ac *addrConn, err error) {
ac.tearDown(err)
}
func (cc *ClientConn) channelzMetric() *channelz.ChannelInternalMetric {
return &channelz.ChannelInternalMetric{
State: cc.GetState(),
Target: cc.target,
CallsStarted: atomic.LoadInt64(&cc.czData.callsStarted),
CallsSucceeded: atomic.LoadInt64(&cc.czData.callsSucceeded),
CallsFailed: atomic.LoadInt64(&cc.czData.callsFailed),
LastCallStartedTimestamp: time.Unix(0, atomic.LoadInt64(&cc.czData.lastCallStartedTime)),
}
}
// Target returns the target string of the ClientConn.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func (cc *ClientConn) Target() string {
return cc.target
}
// CanonicalTarget returns the canonical target string of the ClientConn.
func (cc *ClientConn) CanonicalTarget() string {
return cc.parsedTarget.String()
}
func (cc *ClientConn) incrCallsStarted() {
atomic.AddInt64(&cc.czData.callsStarted, 1)
atomic.StoreInt64(&cc.czData.lastCallStartedTime, time.Now().UnixNano())
cc.channelz.ChannelMetrics.CallsStarted.Add(1)
cc.channelz.ChannelMetrics.LastCallStartedTimestamp.Store(time.Now().UnixNano())
}
func (cc *ClientConn) incrCallsSucceeded() {
atomic.AddInt64(&cc.czData.callsSucceeded, 1)
cc.channelz.ChannelMetrics.CallsSucceeded.Add(1)
}
func (cc *ClientConn) incrCallsFailed() {
atomic.AddInt64(&cc.czData.callsFailed, 1)
cc.channelz.ChannelMetrics.CallsFailed.Add(1)
}
// connect starts creating a transport.
@@ -946,10 +939,14 @@ func equalAddresses(a, b []resolver.Address) bool {
// updateAddrs updates ac.addrs with the new addresses list and handles active
// connections or connection attempts.
func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
ac.mu.Lock()
channelz.Infof(logger, ac.channelzID, "addrConn: updateAddrs curAddr: %v, addrs: %v", pretty.ToJSON(ac.curAddr), pretty.ToJSON(addrs))
addrs = copyAddressesWithoutBalancerAttributes(addrs)
limit := len(addrs)
if limit > 5 {
limit = 5
}
channelz.Infof(logger, ac.channelz, "addrConn: updateAddrs addrs (%d of %d): %v", limit, len(addrs), addrs[:limit])
ac.mu.Lock()
if equalAddresses(ac.addrs, addrs) {
ac.mu.Unlock()
return
@@ -1067,7 +1064,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method st
})
}
func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, configSelector iresolver.ConfigSelector, addrs []resolver.Address) {
func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, configSelector iresolver.ConfigSelector) {
if sc == nil {
// should never reach here.
return
@@ -1088,17 +1085,6 @@ func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, configSel
} else {
cc.retryThrottler.Store((*retryThrottler)(nil))
}
var newBalancerName string
if cc.sc == nil || (cc.sc.lbConfig == nil && cc.sc.LB == nil) {
// No service config or no LB policy specified in config.
newBalancerName = PickFirstBalancerName
} else if cc.sc.lbConfig != nil {
newBalancerName = cc.sc.lbConfig.name
} else { // cc.sc.LB != nil
newBalancerName = *cc.sc.LB
}
cc.balancerWrapper.switchTo(newBalancerName)
}
func (cc *ClientConn) resolveNow(o resolver.ResolveNowOptions) {
@@ -1174,7 +1160,7 @@ func (cc *ClientConn) Close() error {
// TraceEvent needs to be called before RemoveEntry, as TraceEvent may add
// trace reference to the entity being deleted, and thus prevent it from being
// deleted right away.
channelz.RemoveEntry(cc.channelzID)
channelz.RemoveEntry(cc.channelz.ID)
return nil
}
@@ -1195,6 +1181,10 @@ type addrConn struct {
// is received, transport is closed, ac has been torn down).
transport transport.ClientTransport // The current transport.
// This mutex is used on the RPC path, so its usage should be minimized as
// much as possible.
// TODO: Find a lock-free way to retrieve the transport and state from the
// addrConn.
mu sync.Mutex
curAddr resolver.Address // The current address.
addrs []resolver.Address // All addresses that the resolver resolved to.
@@ -1206,8 +1196,7 @@ type addrConn struct {
backoffIdx int // Needs to be stateful for resetConnectBackoff.
resetBackoff chan struct{}
channelzID *channelz.Identifier
czData *channelzData
channelz *channelz.SubChannel
}
// Note: this requires a lock on ac.mu.
@@ -1219,10 +1208,11 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
close(ac.stateChan)
ac.stateChan = make(chan struct{})
ac.state = s
ac.channelz.ChannelMetrics.State.Store(&s)
if lastErr == nil {
channelz.Infof(logger, ac.channelzID, "Subchannel Connectivity change to %v", s)
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v", s)
} else {
channelz.Infof(logger, ac.channelzID, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
}
ac.acbw.updateState(s, lastErr)
}
@@ -1320,6 +1310,7 @@ func (ac *addrConn) resetTransport() {
func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, connectDeadline time.Time) error {
var firstConnErr error
for _, addr := range addrs {
ac.channelz.ChannelMetrics.Target.Store(&addr.Addr)
if ctx.Err() != nil {
return errConnClosing
}
@@ -1335,7 +1326,7 @@ func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, c
}
ac.mu.Unlock()
channelz.Infof(logger, ac.channelzID, "Subchannel picks a new address %q to connect", addr.Addr)
channelz.Infof(logger, ac.channelz, "Subchannel picks a new address %q to connect", addr.Addr)
err := ac.createTransport(ctx, addr, copts, connectDeadline)
if err == nil {
@@ -1388,7 +1379,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
connectCtx, cancel := context.WithDeadline(ctx, connectDeadline)
defer cancel()
copts.ChannelzParentID = ac.channelzID
copts.ChannelzParent = ac.channelz
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onClose)
if err != nil {
@@ -1397,7 +1388,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
}
// newTr is either nil, or closed.
hcancel()
channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addr, err)
channelz.Warningf(logger, ac.channelz, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addr, err)
return err
}
@@ -1469,7 +1460,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
// The health package is not imported to set health check function.
//
// TODO: add a link to the health check doc in the error message.
channelz.Error(logger, ac.channelzID, "Health check is requested but health check function is not set.")
channelz.Error(logger, ac.channelz, "Health check is requested but health check function is not set.")
return
}
@@ -1499,9 +1490,9 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
err := ac.cc.dopts.healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName)
if err != nil {
if status.Code(err) == codes.Unimplemented {
channelz.Error(logger, ac.channelzID, "Subchannel health check is unimplemented at server side, thus health check is disabled")
channelz.Error(logger, ac.channelz, "Subchannel health check is unimplemented at server side, thus health check is disabled")
} else {
channelz.Errorf(logger, ac.channelzID, "Health checking failed: %v", err)
channelz.Errorf(logger, ac.channelz, "Health checking failed: %v", err)
}
}
}()
@@ -1566,18 +1557,18 @@ func (ac *addrConn) tearDown(err error) {
ac.cancel()
ac.curAddr = resolver.Address{}
channelz.AddTraceEvent(logger, ac.channelzID, 0, &channelz.TraceEventDesc{
channelz.AddTraceEvent(logger, ac.channelz, 0, &channelz.TraceEvent{
Desc: "Subchannel deleted",
Severity: channelz.CtInfo,
Parent: &channelz.TraceEventDesc{
Desc: fmt.Sprintf("Subchannel(id:%d) deleted", ac.channelzID.Int()),
Parent: &channelz.TraceEvent{
Desc: fmt.Sprintf("Subchannel(id:%d) deleted", ac.channelz.ID),
Severity: channelz.CtInfo,
},
})
// TraceEvent needs to be called before RemoveEntry, as TraceEvent may add
// trace reference to the entity being deleted, and thus prevent it from
// being deleted right away.
channelz.RemoveEntry(ac.channelzID)
channelz.RemoveEntry(ac.channelz.ID)
ac.mu.Unlock()
// We have to release the lock before the call to GracefulClose/Close here
@@ -1604,39 +1595,6 @@ func (ac *addrConn) tearDown(err error) {
}
}
func (ac *addrConn) getState() connectivity.State {
ac.mu.Lock()
defer ac.mu.Unlock()
return ac.state
}
func (ac *addrConn) ChannelzMetric() *channelz.ChannelInternalMetric {
ac.mu.Lock()
addr := ac.curAddr.Addr
ac.mu.Unlock()
return &channelz.ChannelInternalMetric{
State: ac.getState(),
Target: addr,
CallsStarted: atomic.LoadInt64(&ac.czData.callsStarted),
CallsSucceeded: atomic.LoadInt64(&ac.czData.callsSucceeded),
CallsFailed: atomic.LoadInt64(&ac.czData.callsFailed),
LastCallStartedTimestamp: time.Unix(0, atomic.LoadInt64(&ac.czData.lastCallStartedTime)),
}
}
func (ac *addrConn) incrCallsStarted() {
atomic.AddInt64(&ac.czData.callsStarted, 1)
atomic.StoreInt64(&ac.czData.lastCallStartedTime, time.Now().UnixNano())
}
func (ac *addrConn) incrCallsSucceeded() {
atomic.AddInt64(&ac.czData.callsSucceeded, 1)
}
func (ac *addrConn) incrCallsFailed() {
atomic.AddInt64(&ac.czData.callsFailed, 1)
}
type retryThrottler struct {
max float64
thresh float64
@@ -1674,12 +1632,17 @@ func (rt *retryThrottler) successfulRPC() {
}
}
type channelzChannel struct {
cc *ClientConn
func (ac *addrConn) incrCallsStarted() {
ac.channelz.ChannelMetrics.CallsStarted.Add(1)
ac.channelz.ChannelMetrics.LastCallStartedTimestamp.Store(time.Now().UnixNano())
}
func (c *channelzChannel) ChannelzMetric() *channelz.ChannelInternalMetric {
return c.cc.channelzMetric()
func (ac *addrConn) incrCallsSucceeded() {
ac.channelz.ChannelMetrics.CallsSucceeded.Add(1)
}
func (ac *addrConn) incrCallsFailed() {
ac.channelz.ChannelMetrics.CallsFailed.Add(1)
}
// ErrClientConnTimeout indicates that the ClientConn cannot establish the
@@ -1713,22 +1676,19 @@ func (cc *ClientConn) connectionError() error {
return cc.lastConnectionError
}
// parseTargetAndFindResolver parses the user's dial target and stores the
// parsed target in `cc.parsedTarget`.
// initParsedTargetAndResolverBuilder parses the user's dial target and stores
// the parsed target in `cc.parsedTarget`.
//
// The resolver to use is determined based on the scheme in the parsed target
// and the same is stored in `cc.resolverBuilder`.
//
// Doesn't grab cc.mu as this method is expected to be called only at Dial time.
func (cc *ClientConn) parseTargetAndFindResolver() error {
channelz.Infof(logger, cc.channelzID, "original dial target is: %q", cc.target)
func (cc *ClientConn) initParsedTargetAndResolverBuilder() error {
logger.Infof("original dial target is: %q", cc.target)
var rb resolver.Builder
parsedTarget, err := parseTarget(cc.target)
if err != nil {
channelz.Infof(logger, cc.channelzID, "dial target %q parse failed: %v", cc.target, err)
} else {
channelz.Infof(logger, cc.channelzID, "parsed dial target is: %#v", parsedTarget)
if err == nil {
rb = cc.getResolver(parsedTarget.URL.Scheme)
if rb != nil {
cc.parsedTarget = parsedTarget
@@ -1740,17 +1700,19 @@ func (cc *ClientConn) parseTargetAndFindResolver() error {
// We are here because the user's dial target did not contain a scheme or
// specified an unregistered scheme. We should fallback to the default
// scheme, except when a custom dialer is specified in which case, we should
// always use passthrough scheme.
defScheme := resolver.GetDefaultScheme()
channelz.Infof(logger, cc.channelzID, "fallback to scheme %q", defScheme)
// always use passthrough scheme. For either case, we need to respect any overridden
// global defaults set by the user.
defScheme := cc.dopts.defaultScheme
if internal.UserSetDefaultScheme {
defScheme = resolver.GetDefaultScheme()
}
canonicalTarget := defScheme + ":///" + cc.target
parsedTarget, err = parseTarget(canonicalTarget)
if err != nil {
channelz.Infof(logger, cc.channelzID, "dial target %q parse failed: %v", canonicalTarget, err)
return err
}
channelz.Infof(logger, cc.channelzID, "parsed dial target is: %+v", parsedTarget)
rb = cc.getResolver(parsedTarget.URL.Scheme)
if rb == nil {
return fmt.Errorf("could not get resolver for default scheme: %q", parsedTarget.URL.Scheme)
@@ -1772,6 +1734,8 @@ func parseTarget(target string) (resolver.Target, error) {
return resolver.Target{URL: *u}, nil
}
// encodeAuthority escapes the authority string based on valid chars defined in
// https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.
func encodeAuthority(authority string) string {
const upperhex = "0123456789ABCDEF"
@@ -1788,7 +1752,7 @@ func encodeAuthority(authority string) string {
return false
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=': // Subdelim characters
return false
case ':', '[', ']', '@': // Authority related delimeters
case ':', '[', ']', '@': // Authority related delimiters
return false
}
// Everything else must be escaped.
@@ -1838,7 +1802,7 @@ func encodeAuthority(authority string) string {
// credentials do not match the authority configured through the dial option.
//
// Doesn't grab cc.mu as this method is expected to be called only at Dial time.
func (cc *ClientConn) determineAuthority() error {
func (cc *ClientConn) initAuthority() error {
dopts := cc.dopts
// Historically, we had two options for users to specify the serverName or
// authority for a channel. One was through the transport credentials
@@ -1871,6 +1835,5 @@ func (cc *ClientConn) determineAuthority() error {
} else {
cc.authority = encodeAuthority(endpoint)
}
channelz.Infof(logger, cc.channelzID, "Channel authority set to %q", cc.authority)
return nil
}

View File

@@ -1,17 +0,0 @@
#!/usr/bin/env bash
# This script serves as an example to demonstrate how to generate the gRPC-Go
# interface and the related messages from .proto file.
#
# It assumes the installation of i) Google proto buffer compiler at
# https://github.com/google/protobuf (after v2.6.1) and ii) the Go codegen
# plugin at https://github.com/golang/protobuf (after 2015-02-20). If you have
# not, please install them first.
#
# We recommend running this script at $GOPATH/src.
#
# If this is not what you need, feel free to make your own scripts. Again, this
# script is for demonstration purpose.
#
proto=$1
protoc --go_out=plugins=grpc:. $proto

View File

@@ -235,7 +235,7 @@ func (c *Code) UnmarshalJSON(b []byte) error {
if ci, err := strconv.ParseUint(string(b), 10, 32); err == nil {
if ci >= _maxCode {
return fmt.Errorf("invalid code: %q", ci)
return fmt.Errorf("invalid code: %d", ci)
}
*c = Code(ci)

View File

@@ -28,9 +28,9 @@ import (
"fmt"
"net"
"github.com/golang/protobuf/proto"
"google.golang.org/grpc/attributes"
icredentials "google.golang.org/grpc/internal/credentials"
"google.golang.org/protobuf/proto"
)
// PerRPCCredentials defines the common interface for the credentials which need to
@@ -237,7 +237,7 @@ func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo {
}
// CheckSecurityLevel checks if a connection's security level is greater than or equal to the specified one.
// It returns success if 1) the condition is satisified or 2) AuthInfo struct does not implement GetCommonAuthInfo() method
// It returns success if 1) the condition is satisfied or 2) AuthInfo struct does not implement GetCommonAuthInfo() method
// or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility.
//
// This API is experimental.

View File

@@ -27,9 +27,13 @@ import (
"net/url"
"os"
"google.golang.org/grpc/grpclog"
credinternal "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/envconfig"
)
var logger = grpclog.Component("credentials")
// TLSInfo contains the auth information for a TLS authenticated connection.
// It implements the AuthInfo interface.
type TLSInfo struct {
@@ -112,6 +116,22 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
conn.Close()
return nil, nil, ctx.Err()
}
// The negotiated protocol can be either of the following:
// 1. h2: When the server supports ALPN. Only HTTP/2 can be negotiated since
// it is the only protocol advertised by the client during the handshake.
// The tls library ensures that the server chooses a protocol advertised
// by the client.
// 2. "" (empty string): If the server doesn't support ALPN. ALPN is a requirement
// for using HTTP/2 over TLS. We can terminate the connection immediately.
np := conn.ConnectionState().NegotiatedProtocol
if np == "" {
if envconfig.EnforceALPNEnabled {
conn.Close()
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
}
logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName)
}
tlsInfo := TLSInfo{
State: conn.ConnectionState(),
CommonAuthInfo: CommonAuthInfo{
@@ -131,8 +151,20 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
conn.Close()
return nil, nil, err
}
cs := conn.ConnectionState()
// The negotiated application protocol can be empty only if the client doesn't
// support ALPN. In such cases, we can close the connection since ALPN is required
// for using HTTP/2 over TLS.
if cs.NegotiatedProtocol == "" {
if envconfig.EnforceALPNEnabled {
conn.Close()
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
} else if logger.V(2) {
logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases")
}
}
tlsInfo := TLSInfo{
State: conn.ConnectionState(),
State: cs,
CommonAuthInfo: CommonAuthInfo{
SecurityLevel: PrivacyAndIntegrity,
},

View File

@@ -21,6 +21,7 @@ package grpc
import (
"context"
"net"
"net/url"
"time"
"google.golang.org/grpc/backoff"
@@ -36,6 +37,11 @@ import (
"google.golang.org/grpc/stats"
)
const (
// https://github.com/grpc/proposal/blob/master/A6-client-retries.md#limits-on-retries-and-hedges
defaultMaxCallAttempts = 5
)
func init() {
internal.AddGlobalDialOptions = func(opt ...DialOption) {
globalDialOptions = append(globalDialOptions, opt...)
@@ -43,6 +49,14 @@ func init() {
internal.ClearGlobalDialOptions = func() {
globalDialOptions = nil
}
internal.AddGlobalPerTargetDialOptions = func(opt any) {
if ptdo, ok := opt.(perTargetDialOption); ok {
globalPerTargetDialOptions = append(globalPerTargetDialOptions, ptdo)
}
}
internal.ClearGlobalPerTargetDialOptions = func() {
globalPerTargetDialOptions = nil
}
internal.WithBinaryLogger = withBinaryLogger
internal.JoinDialOptions = newJoinDialOption
internal.DisableGlobalDialOptions = newDisableGlobalDialOptions
@@ -68,7 +82,7 @@ type dialOptions struct {
binaryLogger binarylog.Logger
copts transport.ConnectOptions
callOptions []CallOption
channelzParentID *channelz.Identifier
channelzParent channelz.Identifier
disableServiceConfig bool
disableRetry bool
disableHealthCheck bool
@@ -79,6 +93,8 @@ type dialOptions struct {
resolvers []resolver.Builder
idleTimeout time.Duration
recvBufferPool SharedBufferPool
defaultScheme string
maxCallAttempts int
}
// DialOption configures how we set up the connection.
@@ -88,6 +104,19 @@ type DialOption interface {
var globalDialOptions []DialOption
// perTargetDialOption takes a parsed target and returns a dial option to apply.
//
// This gets called after NewClient() parses the target, and allows per target
// configuration set through a returned DialOption. The DialOption will not take
// effect if specifies a resolver builder, as that Dial Option is factored in
// while parsing target.
type perTargetDialOption interface {
// DialOption returns a Dial Option to apply.
DialOptionForTarget(parsedTarget url.URL) DialOption
}
var globalPerTargetDialOptions []perTargetDialOption
// EmptyDialOption does not alter the dial configuration. It can be embedded in
// another structure to build custom dial options.
//
@@ -154,9 +183,7 @@ func WithSharedWriteBuffer(val bool) DialOption {
}
// WithWriteBufferSize determines how much data can be batched before doing a
// write on the wire. The corresponding memory allocation for this buffer will
// be twice the size to keep syscalls low. The default value for this buffer is
// 32KB.
// write on the wire. The default value for this buffer is 32KB.
//
// Zero or negative values will disable the write buffer such that each write
// will be on underlying connection. Note: A Send call may not directly
@@ -301,6 +328,9 @@ func withBackoff(bs internalbackoff.Strategy) DialOption {
//
// Use of this feature is not recommended. For more information, please see:
// https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md
//
// Deprecated: this DialOption is not supported by NewClient.
// Will be supported throughout 1.x.
func WithBlock() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.block = true
@@ -315,10 +345,8 @@ func WithBlock() DialOption {
// Use of this feature is not recommended. For more information, please see:
// https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
// Deprecated: this DialOption is not supported by NewClient.
// Will be supported throughout 1.x.
func WithReturnConnectionError() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.block = true
@@ -388,8 +416,8 @@ func WithCredentialsBundle(b credentials.Bundle) DialOption {
// WithTimeout returns a DialOption that configures a timeout for dialing a
// ClientConn initially. This is valid if and only if WithBlock() is present.
//
// Deprecated: use DialContext instead of Dial and context.WithTimeout
// instead. Will be supported throughout 1.x.
// Deprecated: this DialOption is not supported by NewClient.
// Will be supported throughout 1.x.
func WithTimeout(d time.Duration) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.timeout = d
@@ -471,9 +499,8 @@ func withBinaryLogger(bl binarylog.Logger) DialOption {
// Use of this feature is not recommended. For more information, please see:
// https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// Deprecated: this DialOption is not supported by NewClient.
// This API may be changed or removed in a
// later release.
func FailOnNonTempDialError(f bool) DialOption {
return newFuncDialOption(func(o *dialOptions) {
@@ -555,9 +582,9 @@ func WithAuthority(a string) DialOption {
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func WithChannelzParentID(id *channelz.Identifier) DialOption {
func WithChannelzParentID(c channelz.Identifier) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.channelzParentID = id
o.channelzParent = c
})
}
@@ -602,12 +629,22 @@ func WithDisableRetry() DialOption {
})
}
// MaxHeaderListSizeDialOption is a DialOption that specifies the maximum
// (uncompressed) size of header list that the client is prepared to accept.
type MaxHeaderListSizeDialOption struct {
MaxHeaderListSize uint32
}
func (o MaxHeaderListSizeDialOption) apply(do *dialOptions) {
do.copts.MaxHeaderListSize = &o.MaxHeaderListSize
}
// WithMaxHeaderListSize returns a DialOption that specifies the maximum
// (uncompressed) size of header list that the client is prepared to accept.
func WithMaxHeaderListSize(s uint32) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.MaxHeaderListSize = &s
})
return MaxHeaderListSizeDialOption{
MaxHeaderListSize: s,
}
}
// WithDisableHealthCheck disables the LB channel health checking for all
@@ -645,10 +682,12 @@ func defaultDialOptions() dialOptions {
healthCheckFunc: internal.HealthCheckFunc,
idleTimeout: 30 * time.Minute,
recvBufferPool: nopBufferPool{},
defaultScheme: "dns",
maxCallAttempts: defaultMaxCallAttempts,
}
}
// withGetMinConnectDeadline specifies the function that clientconn uses to
// withMinConnectDeadline specifies the function that clientconn uses to
// get minConnectDeadline. This can be used to make connection attempts happen
// faster/slower.
//
@@ -659,6 +698,14 @@ func withMinConnectDeadline(f func() time.Duration) DialOption {
})
}
// withDefaultScheme is used to allow Dial to use "passthrough" as the default
// name resolver, while NewClient uses "dns" otherwise.
func withDefaultScheme(s string) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.defaultScheme = s
})
}
// WithResolvers allows a list of resolver implementations to be registered
// locally with the ClientConn without needing to be globally registered via
// resolver.Register. They will be matched against the scheme used for the
@@ -694,6 +741,23 @@ func WithIdleTimeout(d time.Duration) DialOption {
})
}
// WithMaxCallAttempts returns a DialOption that configures the maximum number
// of attempts per call (including retries and hedging) using the channel.
// Service owners may specify a higher value for these parameters, but higher
// values will be treated as equal to the maximum value by the client
// implementation. This mitigates security concerns related to the service
// config being transferred to the client via DNS.
//
// A value of 5 will be used if this dial option is not set or n < 2.
func WithMaxCallAttempts(n int) DialOption {
return newFuncDialOption(func(o *dialOptions) {
if n < 2 {
n = defaultMaxCallAttempts
}
o.maxCallAttempts = n
})
}
// WithRecvBufferPool returns a DialOption that configures the ClientConn
// to use the provided shared buffer pool for parsing incoming messages. Depending
// on the application's workload, this could result in reduced memory allocation.

View File

@@ -17,7 +17,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.32.0
// protoc-gen-go v1.34.1
// protoc v4.25.2
// source: grpc/health/v1/health.proto

View File

@@ -17,7 +17,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.3.0
// - protoc-gen-go-grpc v1.4.0
// - protoc v4.25.2
// source: grpc/health/v1/health.proto
@@ -32,8 +32,8 @@ import (
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// Requires gRPC-Go v1.62.0 or later.
const _ = grpc.SupportPackageIsVersion8
const (
Health_Check_FullMethodName = "/grpc.health.v1.Health/Check"
@@ -43,6 +43,10 @@ const (
// HealthClient is the client API for Health service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
//
// Health is gRPC's mechanism for checking whether a server is able to handle
// RPCs. Its semantics are documented in
// https://github.com/grpc/grpc/blob/master/doc/health-checking.md.
type HealthClient interface {
// Check gets the health of the specified service. If the requested service
// is unknown, the call will fail with status NOT_FOUND. If the caller does
@@ -81,8 +85,9 @@ func NewHealthClient(cc grpc.ClientConnInterface) HealthClient {
}
func (c *healthClient) Check(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(HealthCheckResponse)
err := c.cc.Invoke(ctx, Health_Check_FullMethodName, in, out, opts...)
err := c.cc.Invoke(ctx, Health_Check_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -90,11 +95,12 @@ func (c *healthClient) Check(ctx context.Context, in *HealthCheckRequest, opts .
}
func (c *healthClient) Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (Health_WatchClient, error) {
stream, err := c.cc.NewStream(ctx, &Health_ServiceDesc.Streams[0], Health_Watch_FullMethodName, opts...)
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &Health_ServiceDesc.Streams[0], Health_Watch_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &healthWatchClient{stream}
x := &healthWatchClient{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
@@ -124,6 +130,10 @@ func (x *healthWatchClient) Recv() (*HealthCheckResponse, error) {
// HealthServer is the server API for Health service.
// All implementations should embed UnimplementedHealthServer
// for forward compatibility
//
// Health is gRPC's mechanism for checking whether a server is able to handle
// RPCs. Its semantics are documented in
// https://github.com/grpc/grpc/blob/master/doc/health-checking.md.
type HealthServer interface {
// Check gets the health of the specified service. If the requested service
// is unknown, the call will fail with status NOT_FOUND. If the caller does
@@ -198,7 +208,7 @@ func _Health_Watch_Handler(srv interface{}, stream grpc.ServerStream) error {
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(HealthServer).Watch(m, &healthWatchServer{stream})
return srv.(HealthServer).Watch(m, &healthWatchServer{ServerStream: stream})
}
type Health_WatchServer interface {

View File

@@ -25,10 +25,10 @@ package backoff
import (
"context"
"errors"
"math/rand"
"time"
grpcbackoff "google.golang.org/grpc/backoff"
"google.golang.org/grpc/internal/grpcrand"
)
// Strategy defines the methodology for backing off after a grpc connection
@@ -67,7 +67,7 @@ func (bc Exponential) Backoff(retries int) time.Duration {
}
// Randomize backoff delays so that if a cluster of requests start at
// the same time, they won't operate in lockstep.
backoff *= 1 + bc.Config.Jitter*(grpcrand.Float64()*2-1)
backoff *= 1 + bc.Config.Jitter*(rand.Float64()*2-1)
if backoff < 0 {
return 0
}

View File

@@ -0,0 +1,82 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package gracefulswitch
import (
"encoding/json"
"fmt"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/serviceconfig"
)
type lbConfig struct {
serviceconfig.LoadBalancingConfig
childBuilder balancer.Builder
childConfig serviceconfig.LoadBalancingConfig
}
func ChildName(l serviceconfig.LoadBalancingConfig) string {
return l.(*lbConfig).childBuilder.Name()
}
// ParseConfig parses a child config list and returns a LB config for the
// gracefulswitch Balancer.
//
// cfg is expected to be a json.RawMessage containing a JSON array of LB policy
// names + configs as the format of the "loadBalancingConfig" field in
// ServiceConfig. It returns a type that should be passed to
// UpdateClientConnState in the BalancerConfig field.
func ParseConfig(cfg json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
var lbCfg []map[string]json.RawMessage
if err := json.Unmarshal(cfg, &lbCfg); err != nil {
return nil, err
}
for i, e := range lbCfg {
if len(e) != 1 {
return nil, fmt.Errorf("expected a JSON struct with one entry; received entry %v at index %d", e, i)
}
var name string
var jsonCfg json.RawMessage
for name, jsonCfg = range e {
}
builder := balancer.Get(name)
if builder == nil {
// Skip unregistered balancer names.
continue
}
parser, ok := builder.(balancer.ConfigParser)
if !ok {
// This is a valid child with no config.
return &lbConfig{childBuilder: builder}, nil
}
cfg, err := parser.ParseConfig(jsonCfg)
if err != nil {
return nil, fmt.Errorf("error parsing config for policy %q: %v", name, err)
}
return &lbConfig{childBuilder: builder, childConfig: cfg}, nil
}
return nil, fmt.Errorf("no supported policies found in config: %v", string(cfg))
}

View File

@@ -94,14 +94,23 @@ func (gsb *Balancer) balancerCurrentOrPending(bw *balancerWrapper) bool {
// process is not complete when this method returns. This method must be called
// synchronously alongside the rest of the balancer.Balancer methods this
// Graceful Switch Balancer implements.
//
// Deprecated: use ParseConfig and pass a parsed config to UpdateClientConnState
// to cause the Balancer to automatically change to the new child when necessary.
func (gsb *Balancer) SwitchTo(builder balancer.Builder) error {
_, err := gsb.switchTo(builder)
return err
}
func (gsb *Balancer) switchTo(builder balancer.Builder) (*balancerWrapper, error) {
gsb.mu.Lock()
if gsb.closed {
gsb.mu.Unlock()
return errBalancerClosed
return nil, errBalancerClosed
}
bw := &balancerWrapper{
gsb: gsb,
builder: builder,
gsb: gsb,
lastState: balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable),
@@ -129,7 +138,7 @@ func (gsb *Balancer) SwitchTo(builder balancer.Builder) error {
gsb.balancerCurrent = nil
}
gsb.mu.Unlock()
return balancer.ErrBadResolverState
return nil, balancer.ErrBadResolverState
}
// This write doesn't need to take gsb.mu because this field never gets read
@@ -138,7 +147,7 @@ func (gsb *Balancer) SwitchTo(builder balancer.Builder) error {
// bw.Balancer field will never be forwarded to until this SwitchTo()
// function returns.
bw.Balancer = newBalancer
return nil
return bw, nil
}
// Returns nil if the graceful switch balancer is closed.
@@ -152,12 +161,32 @@ func (gsb *Balancer) latestBalancer() *balancerWrapper {
}
// UpdateClientConnState forwards the update to the latest balancer created.
//
// If the state's BalancerConfig is the config returned by a call to
// gracefulswitch.ParseConfig, then this function will automatically SwitchTo
// the balancer indicated by the config before forwarding its config to it, if
// necessary.
func (gsb *Balancer) UpdateClientConnState(state balancer.ClientConnState) error {
// The resolver data is only relevant to the most recent LB Policy.
balToUpdate := gsb.latestBalancer()
gsbCfg, ok := state.BalancerConfig.(*lbConfig)
if ok {
// Switch to the child in the config unless it is already active.
if balToUpdate == nil || gsbCfg.childBuilder.Name() != balToUpdate.builder.Name() {
var err error
balToUpdate, err = gsb.switchTo(gsbCfg.childBuilder)
if err != nil {
return fmt.Errorf("could not switch to new child balancer: %w", err)
}
}
// Unwrap the child balancer's config.
state.BalancerConfig = gsbCfg.childConfig
}
if balToUpdate == nil {
return errBalancerClosed
}
// Perform this call without gsb.mu to prevent deadlocks if the child calls
// back into the channel. The latest balancer can never be closed during a
// call from the channel, even without gsb.mu held.
@@ -169,6 +198,10 @@ func (gsb *Balancer) ResolverError(err error) {
// The resolver data is only relevant to the most recent LB Policy.
balToUpdate := gsb.latestBalancer()
if balToUpdate == nil {
gsb.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: base.NewErrPicker(err),
})
return
}
// Perform this call without gsb.mu to prevent deadlocks if the child calls
@@ -261,7 +294,8 @@ func (gsb *Balancer) Close() {
// graceful switch logic.
type balancerWrapper struct {
balancer.Balancer
gsb *Balancer
gsb *Balancer
builder balancer.Builder
lastState balancer.State
subconns map[balancer.SubConn]bool // subconns created by this balancer

View File

@@ -65,7 +65,7 @@ type TruncatingMethodLogger struct {
callID uint64
idWithinCallGen *callIDGenerator
sink Sink // TODO(blog): make this plugable.
sink Sink // TODO(blog): make this pluggable.
}
// NewTruncatingMethodLogger returns a new truncating method logger.
@@ -80,7 +80,7 @@ func NewTruncatingMethodLogger(h, m uint64) *TruncatingMethodLogger {
callID: idGen.next(),
idWithinCallGen: &callIDGenerator{},
sink: DefaultSink, // TODO(blog): make it plugable.
sink: DefaultSink, // TODO(blog): make it pluggable.
}
}
@@ -397,7 +397,7 @@ func metadataKeyOmit(key string) bool {
switch key {
case "lb-token", ":path", ":authority", "content-encoding", "content-type", "user-agent", "te":
return true
case "grpc-trace-bin": // grpc-trace-bin is special because it's visiable to users.
case "grpc-trace-bin": // grpc-trace-bin is special because it's visible to users.
return false
}
return strings.HasPrefix(key, "grpc-")

View File

@@ -0,0 +1,255 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"fmt"
"sync/atomic"
"google.golang.org/grpc/connectivity"
)
// Channel represents a channel within channelz, which includes metrics and
// internal channelz data, such as channelz id, child list, etc.
type Channel struct {
Entity
// ID is the channelz id of this channel.
ID int64
// RefName is the human readable reference string of this channel.
RefName string
closeCalled bool
nestedChans map[int64]string
subChans map[int64]string
Parent *Channel
trace *ChannelTrace
// traceRefCount is the number of trace events that reference this channel.
// Non-zero traceRefCount means the trace of this channel cannot be deleted.
traceRefCount int32
ChannelMetrics ChannelMetrics
}
// Implemented to make Channel implement the Identifier interface used for
// nesting.
func (c *Channel) channelzIdentifier() {}
func (c *Channel) String() string {
if c.Parent == nil {
return fmt.Sprintf("Channel #%d", c.ID)
}
return fmt.Sprintf("%s Channel #%d", c.Parent, c.ID)
}
func (c *Channel) id() int64 {
return c.ID
}
func (c *Channel) SubChans() map[int64]string {
db.mu.RLock()
defer db.mu.RUnlock()
return copyMap(c.subChans)
}
func (c *Channel) NestedChans() map[int64]string {
db.mu.RLock()
defer db.mu.RUnlock()
return copyMap(c.nestedChans)
}
func (c *Channel) Trace() *ChannelTrace {
db.mu.RLock()
defer db.mu.RUnlock()
return c.trace.copy()
}
type ChannelMetrics struct {
// The current connectivity state of the channel.
State atomic.Pointer[connectivity.State]
// The target this channel originally tried to connect to. May be absent
Target atomic.Pointer[string]
// The number of calls started on the channel.
CallsStarted atomic.Int64
// The number of calls that have completed with an OK status.
CallsSucceeded atomic.Int64
// The number of calls that have a completed with a non-OK status.
CallsFailed atomic.Int64
// The last time a call was started on the channel.
LastCallStartedTimestamp atomic.Int64
}
// CopyFrom copies the metrics in o to c. For testing only.
func (c *ChannelMetrics) CopyFrom(o *ChannelMetrics) {
c.State.Store(o.State.Load())
c.Target.Store(o.Target.Load())
c.CallsStarted.Store(o.CallsStarted.Load())
c.CallsSucceeded.Store(o.CallsSucceeded.Load())
c.CallsFailed.Store(o.CallsFailed.Load())
c.LastCallStartedTimestamp.Store(o.LastCallStartedTimestamp.Load())
}
// Equal returns true iff the metrics of c are the same as the metrics of o.
// For testing only.
func (c *ChannelMetrics) Equal(o any) bool {
oc, ok := o.(*ChannelMetrics)
if !ok {
return false
}
if (c.State.Load() == nil) != (oc.State.Load() == nil) {
return false
}
if c.State.Load() != nil && *c.State.Load() != *oc.State.Load() {
return false
}
if (c.Target.Load() == nil) != (oc.Target.Load() == nil) {
return false
}
if c.Target.Load() != nil && *c.Target.Load() != *oc.Target.Load() {
return false
}
return c.CallsStarted.Load() == oc.CallsStarted.Load() &&
c.CallsFailed.Load() == oc.CallsFailed.Load() &&
c.CallsSucceeded.Load() == oc.CallsSucceeded.Load() &&
c.LastCallStartedTimestamp.Load() == oc.LastCallStartedTimestamp.Load()
}
func strFromPointer(s *string) string {
if s == nil {
return ""
}
return *s
}
func (c *ChannelMetrics) String() string {
return fmt.Sprintf("State: %v, Target: %s, CallsStarted: %v, CallsSucceeded: %v, CallsFailed: %v, LastCallStartedTimestamp: %v",
c.State.Load(), strFromPointer(c.Target.Load()), c.CallsStarted.Load(), c.CallsSucceeded.Load(), c.CallsFailed.Load(), c.LastCallStartedTimestamp.Load(),
)
}
func NewChannelMetricForTesting(state connectivity.State, target string, started, succeeded, failed, timestamp int64) *ChannelMetrics {
c := &ChannelMetrics{}
c.State.Store(&state)
c.Target.Store(&target)
c.CallsStarted.Store(started)
c.CallsSucceeded.Store(succeeded)
c.CallsFailed.Store(failed)
c.LastCallStartedTimestamp.Store(timestamp)
return c
}
func (c *Channel) addChild(id int64, e entry) {
switch v := e.(type) {
case *SubChannel:
c.subChans[id] = v.RefName
case *Channel:
c.nestedChans[id] = v.RefName
default:
logger.Errorf("cannot add a child (id = %d) of type %T to a channel", id, e)
}
}
func (c *Channel) deleteChild(id int64) {
delete(c.subChans, id)
delete(c.nestedChans, id)
c.deleteSelfIfReady()
}
func (c *Channel) triggerDelete() {
c.closeCalled = true
c.deleteSelfIfReady()
}
func (c *Channel) getParentID() int64 {
if c.Parent == nil {
return -1
}
return c.Parent.ID
}
// deleteSelfFromTree tries to delete the channel from the channelz entry relation tree, which means
// deleting the channel reference from its parent's child list.
//
// In order for a channel to be deleted from the tree, it must meet the criteria that, removal of the
// corresponding grpc object has been invoked, and the channel does not have any children left.
//
// The returned boolean value indicates whether the channel has been successfully deleted from tree.
func (c *Channel) deleteSelfFromTree() (deleted bool) {
if !c.closeCalled || len(c.subChans)+len(c.nestedChans) != 0 {
return false
}
// not top channel
if c.Parent != nil {
c.Parent.deleteChild(c.ID)
}
return true
}
// deleteSelfFromMap checks whether it is valid to delete the channel from the map, which means
// deleting the channel from channelz's tracking entirely. Users can no longer use id to query the
// channel, and its memory will be garbage collected.
//
// The trace reference count of the channel must be 0 in order to be deleted from the map. This is
// specified in the channel tracing gRFC that as long as some other trace has reference to an entity,
// the trace of the referenced entity must not be deleted. In order to release the resource allocated
// by grpc, the reference to the grpc object is reset to a dummy object.
//
// deleteSelfFromMap must be called after deleteSelfFromTree returns true.
//
// It returns a bool to indicate whether the channel can be safely deleted from map.
func (c *Channel) deleteSelfFromMap() (delete bool) {
return c.getTraceRefCount() == 0
}
// deleteSelfIfReady tries to delete the channel itself from the channelz database.
// The delete process includes two steps:
// 1. delete the channel from the entry relation tree, i.e. delete the channel reference from its
// parent's child list.
// 2. delete the channel from the map, i.e. delete the channel entirely from channelz. Lookup by id
// will return entry not found error.
func (c *Channel) deleteSelfIfReady() {
if !c.deleteSelfFromTree() {
return
}
if !c.deleteSelfFromMap() {
return
}
db.deleteEntry(c.ID)
c.trace.clear()
}
func (c *Channel) getChannelTrace() *ChannelTrace {
return c.trace
}
func (c *Channel) incrTraceRefCount() {
atomic.AddInt32(&c.traceRefCount, 1)
}
func (c *Channel) decrTraceRefCount() {
atomic.AddInt32(&c.traceRefCount, -1)
}
func (c *Channel) getTraceRefCount() int {
i := atomic.LoadInt32(&c.traceRefCount)
return int(i)
}
func (c *Channel) getRefName() string {
return c.RefName
}

View File

@@ -0,0 +1,402 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"fmt"
"sort"
"sync"
"time"
)
// entry represents a node in the channelz database.
type entry interface {
// addChild adds a child e, whose channelz id is id to child list
addChild(id int64, e entry)
// deleteChild deletes a child with channelz id to be id from child list
deleteChild(id int64)
// triggerDelete tries to delete self from channelz database. However, if
// child list is not empty, then deletion from the database is on hold until
// the last child is deleted from database.
triggerDelete()
// deleteSelfIfReady check whether triggerDelete() has been called before,
// and whether child list is now empty. If both conditions are met, then
// delete self from database.
deleteSelfIfReady()
// getParentID returns parent ID of the entry. 0 value parent ID means no parent.
getParentID() int64
Entity
}
// channelMap is the storage data structure for channelz.
//
// Methods of channelMap can be divided in two two categories with respect to
// locking.
//
// 1. Methods acquire the global lock.
// 2. Methods that can only be called when global lock is held.
//
// A second type of method need always to be called inside a first type of method.
type channelMap struct {
mu sync.RWMutex
topLevelChannels map[int64]struct{}
channels map[int64]*Channel
subChannels map[int64]*SubChannel
sockets map[int64]*Socket
servers map[int64]*Server
}
func newChannelMap() *channelMap {
return &channelMap{
topLevelChannels: make(map[int64]struct{}),
channels: make(map[int64]*Channel),
subChannels: make(map[int64]*SubChannel),
sockets: make(map[int64]*Socket),
servers: make(map[int64]*Server),
}
}
func (c *channelMap) addServer(id int64, s *Server) {
c.mu.Lock()
defer c.mu.Unlock()
s.cm = c
c.servers[id] = s
}
func (c *channelMap) addChannel(id int64, cn *Channel, isTopChannel bool, pid int64) {
c.mu.Lock()
defer c.mu.Unlock()
cn.trace.cm = c
c.channels[id] = cn
if isTopChannel {
c.topLevelChannels[id] = struct{}{}
} else if p := c.channels[pid]; p != nil {
p.addChild(id, cn)
} else {
logger.Infof("channel %d references invalid parent ID %d", id, pid)
}
}
func (c *channelMap) addSubChannel(id int64, sc *SubChannel, pid int64) {
c.mu.Lock()
defer c.mu.Unlock()
sc.trace.cm = c
c.subChannels[id] = sc
if p := c.channels[pid]; p != nil {
p.addChild(id, sc)
} else {
logger.Infof("subchannel %d references invalid parent ID %d", id, pid)
}
}
func (c *channelMap) addSocket(s *Socket) {
c.mu.Lock()
defer c.mu.Unlock()
s.cm = c
c.sockets[s.ID] = s
if s.Parent == nil {
logger.Infof("normal socket %d has no parent", s.ID)
}
s.Parent.(entry).addChild(s.ID, s)
}
// removeEntry triggers the removal of an entry, which may not indeed delete the
// entry, if it has to wait on the deletion of its children and until no other
// entity's channel trace references it. It may lead to a chain of entry
// deletion. For example, deleting the last socket of a gracefully shutting down
// server will lead to the server being also deleted.
func (c *channelMap) removeEntry(id int64) {
c.mu.Lock()
defer c.mu.Unlock()
c.findEntry(id).triggerDelete()
}
// tracedChannel represents tracing operations which are present on both
// channels and subChannels.
type tracedChannel interface {
getChannelTrace() *ChannelTrace
incrTraceRefCount()
decrTraceRefCount()
getRefName() string
}
// c.mu must be held by the caller
func (c *channelMap) decrTraceRefCount(id int64) {
e := c.findEntry(id)
if v, ok := e.(tracedChannel); ok {
v.decrTraceRefCount()
e.deleteSelfIfReady()
}
}
// c.mu must be held by the caller.
func (c *channelMap) findEntry(id int64) entry {
if v, ok := c.channels[id]; ok {
return v
}
if v, ok := c.subChannels[id]; ok {
return v
}
if v, ok := c.servers[id]; ok {
return v
}
if v, ok := c.sockets[id]; ok {
return v
}
return &dummyEntry{idNotFound: id}
}
// c.mu must be held by the caller
//
// deleteEntry deletes an entry from the channelMap. Before calling this method,
// caller must check this entry is ready to be deleted, i.e removeEntry() has
// been called on it, and no children still exist.
func (c *channelMap) deleteEntry(id int64) entry {
if v, ok := c.sockets[id]; ok {
delete(c.sockets, id)
return v
}
if v, ok := c.subChannels[id]; ok {
delete(c.subChannels, id)
return v
}
if v, ok := c.channels[id]; ok {
delete(c.channels, id)
delete(c.topLevelChannels, id)
return v
}
if v, ok := c.servers[id]; ok {
delete(c.servers, id)
return v
}
return &dummyEntry{idNotFound: id}
}
func (c *channelMap) traceEvent(id int64, desc *TraceEvent) {
c.mu.Lock()
defer c.mu.Unlock()
child := c.findEntry(id)
childTC, ok := child.(tracedChannel)
if !ok {
return
}
childTC.getChannelTrace().append(&traceEvent{Desc: desc.Desc, Severity: desc.Severity, Timestamp: time.Now()})
if desc.Parent != nil {
parent := c.findEntry(child.getParentID())
var chanType RefChannelType
switch child.(type) {
case *Channel:
chanType = RefChannel
case *SubChannel:
chanType = RefSubChannel
}
if parentTC, ok := parent.(tracedChannel); ok {
parentTC.getChannelTrace().append(&traceEvent{
Desc: desc.Parent.Desc,
Severity: desc.Parent.Severity,
Timestamp: time.Now(),
RefID: id,
RefName: childTC.getRefName(),
RefType: chanType,
})
childTC.incrTraceRefCount()
}
}
}
type int64Slice []int64
func (s int64Slice) Len() int { return len(s) }
func (s int64Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s int64Slice) Less(i, j int) bool { return s[i] < s[j] }
func copyMap(m map[int64]string) map[int64]string {
n := make(map[int64]string)
for k, v := range m {
n[k] = v
}
return n
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func (c *channelMap) getTopChannels(id int64, maxResults int) ([]*Channel, bool) {
if maxResults <= 0 {
maxResults = EntriesPerPage
}
c.mu.RLock()
defer c.mu.RUnlock()
l := int64(len(c.topLevelChannels))
ids := make([]int64, 0, l)
for k := range c.topLevelChannels {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= id })
end := true
var t []*Channel
for _, v := range ids[idx:] {
if len(t) == maxResults {
end = false
break
}
if cn, ok := c.channels[v]; ok {
t = append(t, cn)
}
}
return t, end
}
func (c *channelMap) getServers(id int64, maxResults int) ([]*Server, bool) {
if maxResults <= 0 {
maxResults = EntriesPerPage
}
c.mu.RLock()
defer c.mu.RUnlock()
ids := make([]int64, 0, len(c.servers))
for k := range c.servers {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= id })
end := true
var s []*Server
for _, v := range ids[idx:] {
if len(s) == maxResults {
end = false
break
}
if svr, ok := c.servers[v]; ok {
s = append(s, svr)
}
}
return s, end
}
func (c *channelMap) getServerSockets(id int64, startID int64, maxResults int) ([]*Socket, bool) {
if maxResults <= 0 {
maxResults = EntriesPerPage
}
c.mu.RLock()
defer c.mu.RUnlock()
svr, ok := c.servers[id]
if !ok {
// server with id doesn't exist.
return nil, true
}
svrskts := svr.sockets
ids := make([]int64, 0, len(svrskts))
sks := make([]*Socket, 0, min(len(svrskts), maxResults))
for k := range svrskts {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= startID })
end := true
for _, v := range ids[idx:] {
if len(sks) == maxResults {
end = false
break
}
if ns, ok := c.sockets[v]; ok {
sks = append(sks, ns)
}
}
return sks, end
}
func (c *channelMap) getChannel(id int64) *Channel {
c.mu.RLock()
defer c.mu.RUnlock()
return c.channels[id]
}
func (c *channelMap) getSubChannel(id int64) *SubChannel {
c.mu.RLock()
defer c.mu.RUnlock()
return c.subChannels[id]
}
func (c *channelMap) getSocket(id int64) *Socket {
c.mu.RLock()
defer c.mu.RUnlock()
return c.sockets[id]
}
func (c *channelMap) getServer(id int64) *Server {
c.mu.RLock()
defer c.mu.RUnlock()
return c.servers[id]
}
type dummyEntry struct {
// dummyEntry is a fake entry to handle entry not found case.
idNotFound int64
Entity
}
func (d *dummyEntry) String() string {
return fmt.Sprintf("non-existent entity #%d", d.idNotFound)
}
func (d *dummyEntry) ID() int64 { return d.idNotFound }
func (d *dummyEntry) addChild(id int64, e entry) {
// Note: It is possible for a normal program to reach here under race
// condition. For example, there could be a race between ClientConn.Close()
// info being propagated to addrConn and http2Client. ClientConn.Close()
// cancel the context and result in http2Client to error. The error info is
// then caught by transport monitor and before addrConn.tearDown() is called
// in side ClientConn.Close(). Therefore, the addrConn will create a new
// transport. And when registering the new transport in channelz, its parent
// addrConn could have already been torn down and deleted from channelz
// tracking, and thus reach the code here.
logger.Infof("attempt to add child of type %T with id %d to a parent (id=%d) that doesn't currently exist", e, id, d.idNotFound)
}
func (d *dummyEntry) deleteChild(id int64) {
// It is possible for a normal program to reach here under race condition.
// Refer to the example described in addChild().
logger.Infof("attempt to delete child with id %d from a parent (id=%d) that doesn't currently exist", id, d.idNotFound)
}
func (d *dummyEntry) triggerDelete() {
logger.Warningf("attempt to delete an entry (id=%d) that doesn't currently exist", d.idNotFound)
}
func (*dummyEntry) deleteSelfIfReady() {
// code should not reach here. deleteSelfIfReady is always called on an existing entry.
}
func (*dummyEntry) getParentID() int64 {
return 0
}
// Entity is implemented by all channelz types.
type Entity interface {
isEntity()
fmt.Stringer
id() int64
}

View File

@@ -16,47 +16,32 @@
*
*/
// Package channelz defines APIs for enabling channelz service, entry
// Package channelz defines internal APIs for enabling channelz service, entry
// registration/deletion, and accessing channelz data. It also defines channelz
// metric struct formats.
//
// All APIs in this package are experimental.
package channelz
import (
"errors"
"sort"
"sync"
"sync/atomic"
"time"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
)
const (
defaultMaxTraceEntry int32 = 30
)
var (
// IDGen is the global channelz entity ID generator. It should not be used
// outside this package except by tests.
IDGen IDGenerator
db dbWrapper
// EntryPerPage defines the number of channelz entries to be shown on a web page.
EntryPerPage = int64(50)
curState int32
maxTraceEntry = defaultMaxTraceEntry
db *channelMap = newChannelMap()
// EntriesPerPage defines the number of channelz entries to be shown on a web page.
EntriesPerPage = 50
curState int32
)
// TurnOn turns on channelz data collection.
func TurnOn() {
if !IsOn() {
db.set(newChannelMap())
IDGen.Reset()
atomic.StoreInt32(&curState, 1)
}
atomic.StoreInt32(&curState, 1)
}
func init() {
@@ -70,49 +55,15 @@ func IsOn() bool {
return atomic.LoadInt32(&curState) == 1
}
// SetMaxTraceEntry sets maximum number of trace entry per entity (i.e. channel/subchannel).
// Setting it to 0 will disable channel tracing.
func SetMaxTraceEntry(i int32) {
atomic.StoreInt32(&maxTraceEntry, i)
}
// ResetMaxTraceEntryToDefault resets the maximum number of trace entry per entity to default.
func ResetMaxTraceEntryToDefault() {
atomic.StoreInt32(&maxTraceEntry, defaultMaxTraceEntry)
}
func getMaxTraceEntry() int {
i := atomic.LoadInt32(&maxTraceEntry)
return int(i)
}
// dbWarpper wraps around a reference to internal channelz data storage, and
// provide synchronized functionality to set and get the reference.
type dbWrapper struct {
mu sync.RWMutex
DB *channelMap
}
func (d *dbWrapper) set(db *channelMap) {
d.mu.Lock()
d.DB = db
d.mu.Unlock()
}
func (d *dbWrapper) get() *channelMap {
d.mu.RLock()
defer d.mu.RUnlock()
return d.DB
}
// GetTopChannels returns a slice of top channel's ChannelMetric, along with a
// boolean indicating whether there's more top channels to be queried for.
//
// The arg id specifies that only top channel with id at or above it will be included
// in the result. The returned slice is up to a length of the arg maxResults or
// EntryPerPage if maxResults is zero, and is sorted in ascending id order.
func GetTopChannels(id int64, maxResults int64) ([]*ChannelMetric, bool) {
return db.get().GetTopChannels(id, maxResults)
// The arg id specifies that only top channel with id at or above it will be
// included in the result. The returned slice is up to a length of the arg
// maxResults or EntriesPerPage if maxResults is zero, and is sorted in ascending
// id order.
func GetTopChannels(id int64, maxResults int) ([]*Channel, bool) {
return db.getTopChannels(id, maxResults)
}
// GetServers returns a slice of server's ServerMetric, along with a
@@ -120,73 +71,69 @@ func GetTopChannels(id int64, maxResults int64) ([]*ChannelMetric, bool) {
//
// The arg id specifies that only server with id at or above it will be included
// in the result. The returned slice is up to a length of the arg maxResults or
// EntryPerPage if maxResults is zero, and is sorted in ascending id order.
func GetServers(id int64, maxResults int64) ([]*ServerMetric, bool) {
return db.get().GetServers(id, maxResults)
// EntriesPerPage if maxResults is zero, and is sorted in ascending id order.
func GetServers(id int64, maxResults int) ([]*Server, bool) {
return db.getServers(id, maxResults)
}
// GetServerSockets returns a slice of server's (identified by id) normal socket's
// SocketMetric, along with a boolean indicating whether there's more sockets to
// SocketMetrics, along with a boolean indicating whether there's more sockets to
// be queried for.
//
// The arg startID specifies that only sockets with id at or above it will be
// included in the result. The returned slice is up to a length of the arg maxResults
// or EntryPerPage if maxResults is zero, and is sorted in ascending id order.
func GetServerSockets(id int64, startID int64, maxResults int64) ([]*SocketMetric, bool) {
return db.get().GetServerSockets(id, startID, maxResults)
// or EntriesPerPage if maxResults is zero, and is sorted in ascending id order.
func GetServerSockets(id int64, startID int64, maxResults int) ([]*Socket, bool) {
return db.getServerSockets(id, startID, maxResults)
}
// GetChannel returns the ChannelMetric for the channel (identified by id).
func GetChannel(id int64) *ChannelMetric {
return db.get().GetChannel(id)
// GetChannel returns the Channel for the channel (identified by id).
func GetChannel(id int64) *Channel {
return db.getChannel(id)
}
// GetSubChannel returns the SubChannelMetric for the subchannel (identified by id).
func GetSubChannel(id int64) *SubChannelMetric {
return db.get().GetSubChannel(id)
// GetSubChannel returns the SubChannel for the subchannel (identified by id).
func GetSubChannel(id int64) *SubChannel {
return db.getSubChannel(id)
}
// GetSocket returns the SocketInternalMetric for the socket (identified by id).
func GetSocket(id int64) *SocketMetric {
return db.get().GetSocket(id)
// GetSocket returns the Socket for the socket (identified by id).
func GetSocket(id int64) *Socket {
return db.getSocket(id)
}
// GetServer returns the ServerMetric for the server (identified by id).
func GetServer(id int64) *ServerMetric {
return db.get().GetServer(id)
func GetServer(id int64) *Server {
return db.getServer(id)
}
// RegisterChannel registers the given channel c in the channelz database with
// ref as its reference name, and adds it to the child list of its parent
// (identified by pid). pid == nil means no parent.
// target as its target and reference name, and adds it to the child list of its
// parent. parent == nil means no parent.
//
// Returns a unique channelz identifier assigned to this channel.
//
// If channelz is not turned ON, the channelz database is not mutated.
func RegisterChannel(c Channel, pid *Identifier, ref string) *Identifier {
func RegisterChannel(parent *Channel, target string) *Channel {
id := IDGen.genID()
var parent int64
isTopChannel := true
if pid != nil {
isTopChannel = false
parent = pid.Int()
}
if !IsOn() {
return newIdentifer(RefChannel, id, pid)
return &Channel{ID: id}
}
cn := &channel{
refName: ref,
c: c,
subChans: make(map[int64]string),
isTopChannel := parent == nil
cn := &Channel{
ID: id,
RefName: target,
nestedChans: make(map[int64]string),
id: id,
pid: parent,
trace: &channelTrace{createdTime: time.Now(), events: make([]*TraceEvent, 0, getMaxTraceEntry())},
subChans: make(map[int64]string),
Parent: parent,
trace: &ChannelTrace{CreationTime: time.Now(), Events: make([]*traceEvent, 0, getMaxTraceEntry())},
}
db.get().addChannel(id, cn, isTopChannel, parent)
return newIdentifer(RefChannel, id, pid)
cn.ChannelMetrics.Target.Store(&target)
db.addChannel(id, cn, isTopChannel, cn.getParentID())
return cn
}
// RegisterSubChannel registers the given subChannel c in the channelz database
@@ -196,555 +143,67 @@ func RegisterChannel(c Channel, pid *Identifier, ref string) *Identifier {
// Returns a unique channelz identifier assigned to this subChannel.
//
// If channelz is not turned ON, the channelz database is not mutated.
func RegisterSubChannel(c Channel, pid *Identifier, ref string) (*Identifier, error) {
if pid == nil {
return nil, errors.New("a SubChannel's parent id cannot be nil")
}
func RegisterSubChannel(parent *Channel, ref string) *SubChannel {
id := IDGen.genID()
if !IsOn() {
return newIdentifer(RefSubChannel, id, pid), nil
sc := &SubChannel{
ID: id,
RefName: ref,
parent: parent,
}
sc := &subChannel{
refName: ref,
c: c,
sockets: make(map[int64]string),
id: id,
pid: pid.Int(),
trace: &channelTrace{createdTime: time.Now(), events: make([]*TraceEvent, 0, getMaxTraceEntry())},
if !IsOn() {
return sc
}
db.get().addSubChannel(id, sc, pid.Int())
return newIdentifer(RefSubChannel, id, pid), nil
sc.sockets = make(map[int64]string)
sc.trace = &ChannelTrace{CreationTime: time.Now(), Events: make([]*traceEvent, 0, getMaxTraceEntry())}
db.addSubChannel(id, sc, parent.ID)
return sc
}
// RegisterServer registers the given server s in channelz database. It returns
// the unique channelz tracking id assigned to this server.
//
// If channelz is not turned ON, the channelz database is not mutated.
func RegisterServer(s Server, ref string) *Identifier {
func RegisterServer(ref string) *Server {
id := IDGen.genID()
if !IsOn() {
return newIdentifer(RefServer, id, nil)
return &Server{ID: id}
}
svr := &server{
refName: ref,
s: s,
svr := &Server{
RefName: ref,
sockets: make(map[int64]string),
listenSockets: make(map[int64]string),
id: id,
ID: id,
}
db.get().addServer(id, svr)
return newIdentifer(RefServer, id, nil)
db.addServer(id, svr)
return svr
}
// RegisterListenSocket registers the given listen socket s in channelz database
// with ref as its reference name, and add it to the child list of its parent
// (identified by pid). It returns the unique channelz tracking id assigned to
// this listen socket.
//
// If channelz is not turned ON, the channelz database is not mutated.
func RegisterListenSocket(s Socket, pid *Identifier, ref string) (*Identifier, error) {
if pid == nil {
return nil, errors.New("a ListenSocket's parent id cannot be 0")
}
id := IDGen.genID()
if !IsOn() {
return newIdentifer(RefListenSocket, id, pid), nil
}
ls := &listenSocket{refName: ref, s: s, id: id, pid: pid.Int()}
db.get().addListenSocket(id, ls, pid.Int())
return newIdentifer(RefListenSocket, id, pid), nil
}
// RegisterNormalSocket registers the given normal socket s in channelz database
// RegisterSocket registers the given normal socket s in channelz database
// with ref as its reference name, and adds it to the child list of its parent
// (identified by pid). It returns the unique channelz tracking id assigned to
// this normal socket.
// (identified by skt.Parent, which must be set). It returns the unique channelz
// tracking id assigned to this normal socket.
//
// If channelz is not turned ON, the channelz database is not mutated.
func RegisterNormalSocket(s Socket, pid *Identifier, ref string) (*Identifier, error) {
if pid == nil {
return nil, errors.New("a NormalSocket's parent id cannot be 0")
func RegisterSocket(skt *Socket) *Socket {
skt.ID = IDGen.genID()
if IsOn() {
db.addSocket(skt)
}
id := IDGen.genID()
if !IsOn() {
return newIdentifer(RefNormalSocket, id, pid), nil
}
ns := &normalSocket{refName: ref, s: s, id: id, pid: pid.Int()}
db.get().addNormalSocket(id, ns, pid.Int())
return newIdentifer(RefNormalSocket, id, pid), nil
return skt
}
// RemoveEntry removes an entry with unique channelz tracking id to be id from
// channelz database.
//
// If channelz is not turned ON, this function is a no-op.
func RemoveEntry(id *Identifier) {
func RemoveEntry(id int64) {
if !IsOn() {
return
}
db.get().removeEntry(id.Int())
}
// TraceEventDesc is what the caller of AddTraceEvent should provide to describe
// the event to be added to the channel trace.
//
// The Parent field is optional. It is used for an event that will be recorded
// in the entity's parent trace.
type TraceEventDesc struct {
Desc string
Severity Severity
Parent *TraceEventDesc
}
// AddTraceEvent adds trace related to the entity with specified id, using the
// provided TraceEventDesc.
//
// If channelz is not turned ON, this will simply log the event descriptions.
func AddTraceEvent(l grpclog.DepthLoggerV2, id *Identifier, depth int, desc *TraceEventDesc) {
// Log only the trace description associated with the bottom most entity.
switch desc.Severity {
case CtUnknown, CtInfo:
l.InfoDepth(depth+1, withParens(id)+desc.Desc)
case CtWarning:
l.WarningDepth(depth+1, withParens(id)+desc.Desc)
case CtError:
l.ErrorDepth(depth+1, withParens(id)+desc.Desc)
}
if getMaxTraceEntry() == 0 {
return
}
if IsOn() {
db.get().traceEvent(id.Int(), desc)
}
}
// channelMap is the storage data structure for channelz.
// Methods of channelMap can be divided in two two categories with respect to locking.
// 1. Methods acquire the global lock.
// 2. Methods that can only be called when global lock is held.
// A second type of method need always to be called inside a first type of method.
type channelMap struct {
mu sync.RWMutex
topLevelChannels map[int64]struct{}
servers map[int64]*server
channels map[int64]*channel
subChannels map[int64]*subChannel
listenSockets map[int64]*listenSocket
normalSockets map[int64]*normalSocket
}
func newChannelMap() *channelMap {
return &channelMap{
topLevelChannels: make(map[int64]struct{}),
channels: make(map[int64]*channel),
listenSockets: make(map[int64]*listenSocket),
normalSockets: make(map[int64]*normalSocket),
servers: make(map[int64]*server),
subChannels: make(map[int64]*subChannel),
}
}
func (c *channelMap) addServer(id int64, s *server) {
c.mu.Lock()
s.cm = c
c.servers[id] = s
c.mu.Unlock()
}
func (c *channelMap) addChannel(id int64, cn *channel, isTopChannel bool, pid int64) {
c.mu.Lock()
cn.cm = c
cn.trace.cm = c
c.channels[id] = cn
if isTopChannel {
c.topLevelChannels[id] = struct{}{}
} else {
c.findEntry(pid).addChild(id, cn)
}
c.mu.Unlock()
}
func (c *channelMap) addSubChannel(id int64, sc *subChannel, pid int64) {
c.mu.Lock()
sc.cm = c
sc.trace.cm = c
c.subChannels[id] = sc
c.findEntry(pid).addChild(id, sc)
c.mu.Unlock()
}
func (c *channelMap) addListenSocket(id int64, ls *listenSocket, pid int64) {
c.mu.Lock()
ls.cm = c
c.listenSockets[id] = ls
c.findEntry(pid).addChild(id, ls)
c.mu.Unlock()
}
func (c *channelMap) addNormalSocket(id int64, ns *normalSocket, pid int64) {
c.mu.Lock()
ns.cm = c
c.normalSockets[id] = ns
c.findEntry(pid).addChild(id, ns)
c.mu.Unlock()
}
// removeEntry triggers the removal of an entry, which may not indeed delete the entry, if it has to
// wait on the deletion of its children and until no other entity's channel trace references it.
// It may lead to a chain of entry deletion. For example, deleting the last socket of a gracefully
// shutting down server will lead to the server being also deleted.
func (c *channelMap) removeEntry(id int64) {
c.mu.Lock()
c.findEntry(id).triggerDelete()
c.mu.Unlock()
}
// c.mu must be held by the caller
func (c *channelMap) decrTraceRefCount(id int64) {
e := c.findEntry(id)
if v, ok := e.(tracedChannel); ok {
v.decrTraceRefCount()
e.deleteSelfIfReady()
}
}
// c.mu must be held by the caller.
func (c *channelMap) findEntry(id int64) entry {
var v entry
var ok bool
if v, ok = c.channels[id]; ok {
return v
}
if v, ok = c.subChannels[id]; ok {
return v
}
if v, ok = c.servers[id]; ok {
return v
}
if v, ok = c.listenSockets[id]; ok {
return v
}
if v, ok = c.normalSockets[id]; ok {
return v
}
return &dummyEntry{idNotFound: id}
}
// c.mu must be held by the caller
// deleteEntry simply deletes an entry from the channelMap. Before calling this
// method, caller must check this entry is ready to be deleted, i.e removeEntry()
// has been called on it, and no children still exist.
// Conditionals are ordered by the expected frequency of deletion of each entity
// type, in order to optimize performance.
func (c *channelMap) deleteEntry(id int64) {
var ok bool
if _, ok = c.normalSockets[id]; ok {
delete(c.normalSockets, id)
return
}
if _, ok = c.subChannels[id]; ok {
delete(c.subChannels, id)
return
}
if _, ok = c.channels[id]; ok {
delete(c.channels, id)
delete(c.topLevelChannels, id)
return
}
if _, ok = c.listenSockets[id]; ok {
delete(c.listenSockets, id)
return
}
if _, ok = c.servers[id]; ok {
delete(c.servers, id)
return
}
}
func (c *channelMap) traceEvent(id int64, desc *TraceEventDesc) {
c.mu.Lock()
child := c.findEntry(id)
childTC, ok := child.(tracedChannel)
if !ok {
c.mu.Unlock()
return
}
childTC.getChannelTrace().append(&TraceEvent{Desc: desc.Desc, Severity: desc.Severity, Timestamp: time.Now()})
if desc.Parent != nil {
parent := c.findEntry(child.getParentID())
var chanType RefChannelType
switch child.(type) {
case *channel:
chanType = RefChannel
case *subChannel:
chanType = RefSubChannel
}
if parentTC, ok := parent.(tracedChannel); ok {
parentTC.getChannelTrace().append(&TraceEvent{
Desc: desc.Parent.Desc,
Severity: desc.Parent.Severity,
Timestamp: time.Now(),
RefID: id,
RefName: childTC.getRefName(),
RefType: chanType,
})
childTC.incrTraceRefCount()
}
}
c.mu.Unlock()
}
type int64Slice []int64
func (s int64Slice) Len() int { return len(s) }
func (s int64Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s int64Slice) Less(i, j int) bool { return s[i] < s[j] }
func copyMap(m map[int64]string) map[int64]string {
n := make(map[int64]string)
for k, v := range m {
n[k] = v
}
return n
}
func min(a, b int64) int64 {
if a < b {
return a
}
return b
}
func (c *channelMap) GetTopChannels(id int64, maxResults int64) ([]*ChannelMetric, bool) {
if maxResults <= 0 {
maxResults = EntryPerPage
}
c.mu.RLock()
l := int64(len(c.topLevelChannels))
ids := make([]int64, 0, l)
cns := make([]*channel, 0, min(l, maxResults))
for k := range c.topLevelChannels {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= id })
count := int64(0)
var end bool
var t []*ChannelMetric
for i, v := range ids[idx:] {
if count == maxResults {
break
}
if cn, ok := c.channels[v]; ok {
cns = append(cns, cn)
t = append(t, &ChannelMetric{
NestedChans: copyMap(cn.nestedChans),
SubChans: copyMap(cn.subChans),
})
count++
}
if i == len(ids[idx:])-1 {
end = true
break
}
}
c.mu.RUnlock()
if count == 0 {
end = true
}
for i, cn := range cns {
t[i].ChannelData = cn.c.ChannelzMetric()
t[i].ID = cn.id
t[i].RefName = cn.refName
t[i].Trace = cn.trace.dumpData()
}
return t, end
}
func (c *channelMap) GetServers(id, maxResults int64) ([]*ServerMetric, bool) {
if maxResults <= 0 {
maxResults = EntryPerPage
}
c.mu.RLock()
l := int64(len(c.servers))
ids := make([]int64, 0, l)
ss := make([]*server, 0, min(l, maxResults))
for k := range c.servers {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= id })
count := int64(0)
var end bool
var s []*ServerMetric
for i, v := range ids[idx:] {
if count == maxResults {
break
}
if svr, ok := c.servers[v]; ok {
ss = append(ss, svr)
s = append(s, &ServerMetric{
ListenSockets: copyMap(svr.listenSockets),
})
count++
}
if i == len(ids[idx:])-1 {
end = true
break
}
}
c.mu.RUnlock()
if count == 0 {
end = true
}
for i, svr := range ss {
s[i].ServerData = svr.s.ChannelzMetric()
s[i].ID = svr.id
s[i].RefName = svr.refName
}
return s, end
}
func (c *channelMap) GetServerSockets(id int64, startID int64, maxResults int64) ([]*SocketMetric, bool) {
if maxResults <= 0 {
maxResults = EntryPerPage
}
var svr *server
var ok bool
c.mu.RLock()
if svr, ok = c.servers[id]; !ok {
// server with id doesn't exist.
c.mu.RUnlock()
return nil, true
}
svrskts := svr.sockets
l := int64(len(svrskts))
ids := make([]int64, 0, l)
sks := make([]*normalSocket, 0, min(l, maxResults))
for k := range svrskts {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= startID })
count := int64(0)
var end bool
for i, v := range ids[idx:] {
if count == maxResults {
break
}
if ns, ok := c.normalSockets[v]; ok {
sks = append(sks, ns)
count++
}
if i == len(ids[idx:])-1 {
end = true
break
}
}
c.mu.RUnlock()
if count == 0 {
end = true
}
s := make([]*SocketMetric, 0, len(sks))
for _, ns := range sks {
sm := &SocketMetric{}
sm.SocketData = ns.s.ChannelzMetric()
sm.ID = ns.id
sm.RefName = ns.refName
s = append(s, sm)
}
return s, end
}
func (c *channelMap) GetChannel(id int64) *ChannelMetric {
cm := &ChannelMetric{}
var cn *channel
var ok bool
c.mu.RLock()
if cn, ok = c.channels[id]; !ok {
// channel with id doesn't exist.
c.mu.RUnlock()
return nil
}
cm.NestedChans = copyMap(cn.nestedChans)
cm.SubChans = copyMap(cn.subChans)
// cn.c can be set to &dummyChannel{} when deleteSelfFromMap is called. Save a copy of cn.c when
// holding the lock to prevent potential data race.
chanCopy := cn.c
c.mu.RUnlock()
cm.ChannelData = chanCopy.ChannelzMetric()
cm.ID = cn.id
cm.RefName = cn.refName
cm.Trace = cn.trace.dumpData()
return cm
}
func (c *channelMap) GetSubChannel(id int64) *SubChannelMetric {
cm := &SubChannelMetric{}
var sc *subChannel
var ok bool
c.mu.RLock()
if sc, ok = c.subChannels[id]; !ok {
// subchannel with id doesn't exist.
c.mu.RUnlock()
return nil
}
cm.Sockets = copyMap(sc.sockets)
// sc.c can be set to &dummyChannel{} when deleteSelfFromMap is called. Save a copy of sc.c when
// holding the lock to prevent potential data race.
chanCopy := sc.c
c.mu.RUnlock()
cm.ChannelData = chanCopy.ChannelzMetric()
cm.ID = sc.id
cm.RefName = sc.refName
cm.Trace = sc.trace.dumpData()
return cm
}
func (c *channelMap) GetSocket(id int64) *SocketMetric {
sm := &SocketMetric{}
c.mu.RLock()
if ls, ok := c.listenSockets[id]; ok {
c.mu.RUnlock()
sm.SocketData = ls.s.ChannelzMetric()
sm.ID = ls.id
sm.RefName = ls.refName
return sm
}
if ns, ok := c.normalSockets[id]; ok {
c.mu.RUnlock()
sm.SocketData = ns.s.ChannelzMetric()
sm.ID = ns.id
sm.RefName = ns.refName
return sm
}
c.mu.RUnlock()
return nil
}
func (c *channelMap) GetServer(id int64) *ServerMetric {
sm := &ServerMetric{}
var svr *server
var ok bool
c.mu.RLock()
if svr, ok = c.servers[id]; !ok {
c.mu.RUnlock()
return nil
}
sm.ListenSockets = copyMap(svr.listenSockets)
c.mu.RUnlock()
sm.ID = svr.id
sm.RefName = svr.refName
sm.ServerData = svr.s.ChannelzMetric()
return sm
db.removeEntry(id)
}
// IDGenerator is an incrementing atomic that tracks IDs for channelz entities.
@@ -761,3 +220,11 @@ func (i *IDGenerator) Reset() {
func (i *IDGenerator) genID() int64 {
return atomic.AddInt64(&i.id, 1)
}
// Identifier is an opaque channelz identifier used to expose channelz symbols
// outside of grpc. Currently only implemented by Channel since no other
// types require exposure outside grpc.
type Identifier interface {
Entity
channelzIdentifier()
}

View File

@@ -1,75 +0,0 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import "fmt"
// Identifier is an opaque identifier which uniquely identifies an entity in the
// channelz database.
type Identifier struct {
typ RefChannelType
id int64
str string
pid *Identifier
}
// Type returns the entity type corresponding to id.
func (id *Identifier) Type() RefChannelType {
return id.typ
}
// Int returns the integer identifier corresponding to id.
func (id *Identifier) Int() int64 {
return id.id
}
// String returns a string representation of the entity corresponding to id.
//
// This includes some information about the parent as well. Examples:
// Top-level channel: [Channel #channel-number]
// Nested channel: [Channel #parent-channel-number Channel #channel-number]
// Sub channel: [Channel #parent-channel SubChannel #subchannel-number]
func (id *Identifier) String() string {
return id.str
}
// Equal returns true if other is the same as id.
func (id *Identifier) Equal(other *Identifier) bool {
if (id != nil) != (other != nil) {
return false
}
if id == nil && other == nil {
return true
}
return id.typ == other.typ && id.id == other.id && id.pid == other.pid
}
// NewIdentifierForTesting returns a new opaque identifier to be used only for
// testing purposes.
func NewIdentifierForTesting(typ RefChannelType, id int64, pid *Identifier) *Identifier {
return newIdentifer(typ, id, pid)
}
func newIdentifer(typ RefChannelType, id int64, pid *Identifier) *Identifier {
str := fmt.Sprintf("%s #%d", typ, id)
if pid != nil {
str = fmt.Sprintf("%s %s", pid, str)
}
return &Identifier{typ: typ, id: id, str: str, pid: pid}
}

View File

@@ -26,53 +26,49 @@ import (
var logger = grpclog.Component("channelz")
func withParens(id *Identifier) string {
return "[" + id.String() + "] "
}
// Info logs and adds a trace event if channelz is on.
func Info(l grpclog.DepthLoggerV2, id *Identifier, args ...any) {
AddTraceEvent(l, id, 1, &TraceEventDesc{
func Info(l grpclog.DepthLoggerV2, e Entity, args ...any) {
AddTraceEvent(l, e, 1, &TraceEvent{
Desc: fmt.Sprint(args...),
Severity: CtInfo,
})
}
// Infof logs and adds a trace event if channelz is on.
func Infof(l grpclog.DepthLoggerV2, id *Identifier, format string, args ...any) {
AddTraceEvent(l, id, 1, &TraceEventDesc{
func Infof(l grpclog.DepthLoggerV2, e Entity, format string, args ...any) {
AddTraceEvent(l, e, 1, &TraceEvent{
Desc: fmt.Sprintf(format, args...),
Severity: CtInfo,
})
}
// Warning logs and adds a trace event if channelz is on.
func Warning(l grpclog.DepthLoggerV2, id *Identifier, args ...any) {
AddTraceEvent(l, id, 1, &TraceEventDesc{
func Warning(l grpclog.DepthLoggerV2, e Entity, args ...any) {
AddTraceEvent(l, e, 1, &TraceEvent{
Desc: fmt.Sprint(args...),
Severity: CtWarning,
})
}
// Warningf logs and adds a trace event if channelz is on.
func Warningf(l grpclog.DepthLoggerV2, id *Identifier, format string, args ...any) {
AddTraceEvent(l, id, 1, &TraceEventDesc{
func Warningf(l grpclog.DepthLoggerV2, e Entity, format string, args ...any) {
AddTraceEvent(l, e, 1, &TraceEvent{
Desc: fmt.Sprintf(format, args...),
Severity: CtWarning,
})
}
// Error logs and adds a trace event if channelz is on.
func Error(l grpclog.DepthLoggerV2, id *Identifier, args ...any) {
AddTraceEvent(l, id, 1, &TraceEventDesc{
func Error(l grpclog.DepthLoggerV2, e Entity, args ...any) {
AddTraceEvent(l, e, 1, &TraceEvent{
Desc: fmt.Sprint(args...),
Severity: CtError,
})
}
// Errorf logs and adds a trace event if channelz is on.
func Errorf(l grpclog.DepthLoggerV2, id *Identifier, format string, args ...any) {
AddTraceEvent(l, id, 1, &TraceEventDesc{
func Errorf(l grpclog.DepthLoggerV2, e Entity, format string, args ...any) {
AddTraceEvent(l, e, 1, &TraceEvent{
Desc: fmt.Sprintf(format, args...),
Severity: CtError,
})

View File

@@ -0,0 +1,119 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"fmt"
"sync/atomic"
)
// Server is the channelz representation of a server.
type Server struct {
Entity
ID int64
RefName string
ServerMetrics ServerMetrics
closeCalled bool
sockets map[int64]string
listenSockets map[int64]string
cm *channelMap
}
// ServerMetrics defines a struct containing metrics for servers.
type ServerMetrics struct {
// The number of incoming calls started on the server.
CallsStarted atomic.Int64
// The number of incoming calls that have completed with an OK status.
CallsSucceeded atomic.Int64
// The number of incoming calls that have a completed with a non-OK status.
CallsFailed atomic.Int64
// The last time a call was started on the server.
LastCallStartedTimestamp atomic.Int64
}
// NewServerMetricsForTesting returns an initialized ServerMetrics.
func NewServerMetricsForTesting(started, succeeded, failed, timestamp int64) *ServerMetrics {
sm := &ServerMetrics{}
sm.CallsStarted.Store(started)
sm.CallsSucceeded.Store(succeeded)
sm.CallsFailed.Store(failed)
sm.LastCallStartedTimestamp.Store(timestamp)
return sm
}
func (sm *ServerMetrics) CopyFrom(o *ServerMetrics) {
sm.CallsStarted.Store(o.CallsStarted.Load())
sm.CallsSucceeded.Store(o.CallsSucceeded.Load())
sm.CallsFailed.Store(o.CallsFailed.Load())
sm.LastCallStartedTimestamp.Store(o.LastCallStartedTimestamp.Load())
}
// ListenSockets returns the listening sockets for s.
func (s *Server) ListenSockets() map[int64]string {
db.mu.RLock()
defer db.mu.RUnlock()
return copyMap(s.listenSockets)
}
// String returns a printable description of s.
func (s *Server) String() string {
return fmt.Sprintf("Server #%d", s.ID)
}
func (s *Server) id() int64 {
return s.ID
}
func (s *Server) addChild(id int64, e entry) {
switch v := e.(type) {
case *Socket:
switch v.SocketType {
case SocketTypeNormal:
s.sockets[id] = v.RefName
case SocketTypeListen:
s.listenSockets[id] = v.RefName
}
default:
logger.Errorf("cannot add a child (id = %d) of type %T to a server", id, e)
}
}
func (s *Server) deleteChild(id int64) {
delete(s.sockets, id)
delete(s.listenSockets, id)
s.deleteSelfIfReady()
}
func (s *Server) triggerDelete() {
s.closeCalled = true
s.deleteSelfIfReady()
}
func (s *Server) deleteSelfIfReady() {
if !s.closeCalled || len(s.sockets)+len(s.listenSockets) != 0 {
return
}
s.cm.deleteEntry(s.ID)
}
func (s *Server) getParentID() int64 {
return 0
}

View File

@@ -0,0 +1,130 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"fmt"
"net"
"sync/atomic"
"google.golang.org/grpc/credentials"
)
// SocketMetrics defines the struct that the implementor of Socket interface
// should return from ChannelzMetric().
type SocketMetrics struct {
// The number of streams that have been started.
StreamsStarted atomic.Int64
// The number of streams that have ended successfully:
// On client side, receiving frame with eos bit set.
// On server side, sending frame with eos bit set.
StreamsSucceeded atomic.Int64
// The number of streams that have ended unsuccessfully:
// On client side, termination without receiving frame with eos bit set.
// On server side, termination without sending frame with eos bit set.
StreamsFailed atomic.Int64
// The number of messages successfully sent on this socket.
MessagesSent atomic.Int64
MessagesReceived atomic.Int64
// The number of keep alives sent. This is typically implemented with HTTP/2
// ping messages.
KeepAlivesSent atomic.Int64
// The last time a stream was created by this endpoint. Usually unset for
// servers.
LastLocalStreamCreatedTimestamp atomic.Int64
// The last time a stream was created by the remote endpoint. Usually unset
// for clients.
LastRemoteStreamCreatedTimestamp atomic.Int64
// The last time a message was sent by this endpoint.
LastMessageSentTimestamp atomic.Int64
// The last time a message was received by this endpoint.
LastMessageReceivedTimestamp atomic.Int64
}
// EphemeralSocketMetrics are metrics that change rapidly and are tracked
// outside of channelz.
type EphemeralSocketMetrics struct {
// The amount of window, granted to the local endpoint by the remote endpoint.
// This may be slightly out of date due to network latency. This does NOT
// include stream level or TCP level flow control info.
LocalFlowControlWindow int64
// The amount of window, granted to the remote endpoint by the local endpoint.
// This may be slightly out of date due to network latency. This does NOT
// include stream level or TCP level flow control info.
RemoteFlowControlWindow int64
}
type SocketType string
const (
SocketTypeNormal = "NormalSocket"
SocketTypeListen = "ListenSocket"
)
type Socket struct {
Entity
SocketType SocketType
ID int64
Parent Entity
cm *channelMap
SocketMetrics SocketMetrics
EphemeralMetrics func() *EphemeralSocketMetrics
RefName string
// The locally bound address. Immutable.
LocalAddr net.Addr
// The remote bound address. May be absent. Immutable.
RemoteAddr net.Addr
// Optional, represents the name of the remote endpoint, if different than
// the original target name. Immutable.
RemoteName string
// Immutable.
SocketOptions *SocketOptionData
// Immutable.
Security credentials.ChannelzSecurityValue
}
func (ls *Socket) String() string {
return fmt.Sprintf("%s %s #%d", ls.Parent, ls.SocketType, ls.ID)
}
func (ls *Socket) id() int64 {
return ls.ID
}
func (ls *Socket) addChild(id int64, e entry) {
logger.Errorf("cannot add a child (id = %d) of type %T to a listen socket", id, e)
}
func (ls *Socket) deleteChild(id int64) {
logger.Errorf("cannot delete a child (id = %d) from a listen socket", id)
}
func (ls *Socket) triggerDelete() {
ls.cm.deleteEntry(ls.ID)
ls.Parent.(entry).deleteChild(ls.ID)
}
func (ls *Socket) deleteSelfIfReady() {
logger.Errorf("cannot call deleteSelfIfReady on a listen socket")
}
func (ls *Socket) getParentID() int64 {
return ls.Parent.id()
}

View File

@@ -0,0 +1,151 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"fmt"
"sync/atomic"
)
// SubChannel is the channelz representation of a subchannel.
type SubChannel struct {
Entity
// ID is the channelz id of this subchannel.
ID int64
// RefName is the human readable reference string of this subchannel.
RefName string
closeCalled bool
sockets map[int64]string
parent *Channel
trace *ChannelTrace
traceRefCount int32
ChannelMetrics ChannelMetrics
}
func (sc *SubChannel) String() string {
return fmt.Sprintf("%s SubChannel #%d", sc.parent, sc.ID)
}
func (sc *SubChannel) id() int64 {
return sc.ID
}
func (sc *SubChannel) Sockets() map[int64]string {
db.mu.RLock()
defer db.mu.RUnlock()
return copyMap(sc.sockets)
}
func (sc *SubChannel) Trace() *ChannelTrace {
db.mu.RLock()
defer db.mu.RUnlock()
return sc.trace.copy()
}
func (sc *SubChannel) addChild(id int64, e entry) {
if v, ok := e.(*Socket); ok && v.SocketType == SocketTypeNormal {
sc.sockets[id] = v.RefName
} else {
logger.Errorf("cannot add a child (id = %d) of type %T to a subChannel", id, e)
}
}
func (sc *SubChannel) deleteChild(id int64) {
delete(sc.sockets, id)
sc.deleteSelfIfReady()
}
func (sc *SubChannel) triggerDelete() {
sc.closeCalled = true
sc.deleteSelfIfReady()
}
func (sc *SubChannel) getParentID() int64 {
return sc.parent.ID
}
// deleteSelfFromTree tries to delete the subchannel from the channelz entry relation tree, which
// means deleting the subchannel reference from its parent's child list.
//
// In order for a subchannel to be deleted from the tree, it must meet the criteria that, removal of
// the corresponding grpc object has been invoked, and the subchannel does not have any children left.
//
// The returned boolean value indicates whether the channel has been successfully deleted from tree.
func (sc *SubChannel) deleteSelfFromTree() (deleted bool) {
if !sc.closeCalled || len(sc.sockets) != 0 {
return false
}
sc.parent.deleteChild(sc.ID)
return true
}
// deleteSelfFromMap checks whether it is valid to delete the subchannel from the map, which means
// deleting the subchannel from channelz's tracking entirely. Users can no longer use id to query
// the subchannel, and its memory will be garbage collected.
//
// The trace reference count of the subchannel must be 0 in order to be deleted from the map. This is
// specified in the channel tracing gRFC that as long as some other trace has reference to an entity,
// the trace of the referenced entity must not be deleted. In order to release the resource allocated
// by grpc, the reference to the grpc object is reset to a dummy object.
//
// deleteSelfFromMap must be called after deleteSelfFromTree returns true.
//
// It returns a bool to indicate whether the channel can be safely deleted from map.
func (sc *SubChannel) deleteSelfFromMap() (delete bool) {
return sc.getTraceRefCount() == 0
}
// deleteSelfIfReady tries to delete the subchannel itself from the channelz database.
// The delete process includes two steps:
// 1. delete the subchannel from the entry relation tree, i.e. delete the subchannel reference from
// its parent's child list.
// 2. delete the subchannel from the map, i.e. delete the subchannel entirely from channelz. Lookup
// by id will return entry not found error.
func (sc *SubChannel) deleteSelfIfReady() {
if !sc.deleteSelfFromTree() {
return
}
if !sc.deleteSelfFromMap() {
return
}
db.deleteEntry(sc.ID)
sc.trace.clear()
}
func (sc *SubChannel) getChannelTrace() *ChannelTrace {
return sc.trace
}
func (sc *SubChannel) incrTraceRefCount() {
atomic.AddInt32(&sc.traceRefCount, 1)
}
func (sc *SubChannel) decrTraceRefCount() {
atomic.AddInt32(&sc.traceRefCount, -1)
}
func (sc *SubChannel) getTraceRefCount() int {
i := atomic.LoadInt32(&sc.traceRefCount)
return int(i)
}
func (sc *SubChannel) getRefName() string {
return sc.RefName
}

View File

@@ -49,3 +49,17 @@ func (s *SocketOptionData) Getsockopt(fd uintptr) {
s.TCPInfo = v
}
}
// GetSocketOption gets the socket option info of the conn.
func GetSocketOption(socket any) *SocketOptionData {
c, ok := socket.(syscall.Conn)
if !ok {
return nil
}
data := &SocketOptionData{}
if rawConn, err := c.SyscallConn(); err == nil {
rawConn.Control(data.Getsockopt)
return data
}
return nil
}

View File

@@ -1,5 +1,4 @@
//go:build !linux
// +build !linux
/*
*
@@ -41,3 +40,8 @@ func (s *SocketOptionData) Getsockopt(fd uintptr) {
logger.Warning("Channelz: socket options are not supported on non-linux environments")
})
}
// GetSocketOption gets the socket option info of the conn.
func GetSocketOption(c any) *SocketOptionData {
return nil
}

View File

@@ -0,0 +1,204 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"fmt"
"sync"
"sync/atomic"
"time"
"google.golang.org/grpc/grpclog"
)
const (
defaultMaxTraceEntry int32 = 30
)
var maxTraceEntry = defaultMaxTraceEntry
// SetMaxTraceEntry sets maximum number of trace entries per entity (i.e.
// channel/subchannel). Setting it to 0 will disable channel tracing.
func SetMaxTraceEntry(i int32) {
atomic.StoreInt32(&maxTraceEntry, i)
}
// ResetMaxTraceEntryToDefault resets the maximum number of trace entries per
// entity to default.
func ResetMaxTraceEntryToDefault() {
atomic.StoreInt32(&maxTraceEntry, defaultMaxTraceEntry)
}
func getMaxTraceEntry() int {
i := atomic.LoadInt32(&maxTraceEntry)
return int(i)
}
// traceEvent is an internal representation of a single trace event
type traceEvent struct {
// Desc is a simple description of the trace event.
Desc string
// Severity states the severity of this trace event.
Severity Severity
// Timestamp is the event time.
Timestamp time.Time
// RefID is the id of the entity that gets referenced in the event. RefID is 0 if no other entity is
// involved in this event.
// e.g. SubChannel (id: 4[]) Created. --> RefID = 4, RefName = "" (inside [])
RefID int64
// RefName is the reference name for the entity that gets referenced in the event.
RefName string
// RefType indicates the referenced entity type, i.e Channel or SubChannel.
RefType RefChannelType
}
// TraceEvent is what the caller of AddTraceEvent should provide to describe the
// event to be added to the channel trace.
//
// The Parent field is optional. It is used for an event that will be recorded
// in the entity's parent trace.
type TraceEvent struct {
Desc string
Severity Severity
Parent *TraceEvent
}
type ChannelTrace struct {
cm *channelMap
clearCalled bool
CreationTime time.Time
EventNum int64
mu sync.Mutex
Events []*traceEvent
}
func (c *ChannelTrace) copy() *ChannelTrace {
return &ChannelTrace{
CreationTime: c.CreationTime,
EventNum: c.EventNum,
Events: append(([]*traceEvent)(nil), c.Events...),
}
}
func (c *ChannelTrace) append(e *traceEvent) {
c.mu.Lock()
if len(c.Events) == getMaxTraceEntry() {
del := c.Events[0]
c.Events = c.Events[1:]
if del.RefID != 0 {
// start recursive cleanup in a goroutine to not block the call originated from grpc.
go func() {
// need to acquire c.cm.mu lock to call the unlocked attemptCleanup func.
c.cm.mu.Lock()
c.cm.decrTraceRefCount(del.RefID)
c.cm.mu.Unlock()
}()
}
}
e.Timestamp = time.Now()
c.Events = append(c.Events, e)
c.EventNum++
c.mu.Unlock()
}
func (c *ChannelTrace) clear() {
if c.clearCalled {
return
}
c.clearCalled = true
c.mu.Lock()
for _, e := range c.Events {
if e.RefID != 0 {
// caller should have already held the c.cm.mu lock.
c.cm.decrTraceRefCount(e.RefID)
}
}
c.mu.Unlock()
}
// Severity is the severity level of a trace event.
// The canonical enumeration of all valid values is here:
// https://github.com/grpc/grpc-proto/blob/9b13d199cc0d4703c7ea26c9c330ba695866eb23/grpc/channelz/v1/channelz.proto#L126.
type Severity int
const (
// CtUnknown indicates unknown severity of a trace event.
CtUnknown Severity = iota
// CtInfo indicates info level severity of a trace event.
CtInfo
// CtWarning indicates warning level severity of a trace event.
CtWarning
// CtError indicates error level severity of a trace event.
CtError
)
// RefChannelType is the type of the entity being referenced in a trace event.
type RefChannelType int
const (
// RefUnknown indicates an unknown entity type, the zero value for this type.
RefUnknown RefChannelType = iota
// RefChannel indicates the referenced entity is a Channel.
RefChannel
// RefSubChannel indicates the referenced entity is a SubChannel.
RefSubChannel
// RefServer indicates the referenced entity is a Server.
RefServer
// RefListenSocket indicates the referenced entity is a ListenSocket.
RefListenSocket
// RefNormalSocket indicates the referenced entity is a NormalSocket.
RefNormalSocket
)
var refChannelTypeToString = map[RefChannelType]string{
RefUnknown: "Unknown",
RefChannel: "Channel",
RefSubChannel: "SubChannel",
RefServer: "Server",
RefListenSocket: "ListenSocket",
RefNormalSocket: "NormalSocket",
}
func (r RefChannelType) String() string {
return refChannelTypeToString[r]
}
// AddTraceEvent adds trace related to the entity with specified id, using the
// provided TraceEventDesc.
//
// If channelz is not turned ON, this will simply log the event descriptions.
func AddTraceEvent(l grpclog.DepthLoggerV2, e Entity, depth int, desc *TraceEvent) {
// Log only the trace description associated with the bottom most entity.
d := fmt.Sprintf("[%s]%s", e, desc.Desc)
switch desc.Severity {
case CtUnknown, CtInfo:
l.InfoDepth(depth+1, d)
case CtWarning:
l.WarningDepth(depth+1, d)
case CtError:
l.ErrorDepth(depth+1, d)
}
if getMaxTraceEntry() == 0 {
return
}
if IsOn() {
db.traceEvent(e.id(), desc)
}
}

View File

@@ -1,727 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"net"
"sync"
"sync/atomic"
"time"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
)
// entry represents a node in the channelz database.
type entry interface {
// addChild adds a child e, whose channelz id is id to child list
addChild(id int64, e entry)
// deleteChild deletes a child with channelz id to be id from child list
deleteChild(id int64)
// triggerDelete tries to delete self from channelz database. However, if child
// list is not empty, then deletion from the database is on hold until the last
// child is deleted from database.
triggerDelete()
// deleteSelfIfReady check whether triggerDelete() has been called before, and whether child
// list is now empty. If both conditions are met, then delete self from database.
deleteSelfIfReady()
// getParentID returns parent ID of the entry. 0 value parent ID means no parent.
getParentID() int64
}
// dummyEntry is a fake entry to handle entry not found case.
type dummyEntry struct {
idNotFound int64
}
func (d *dummyEntry) addChild(id int64, e entry) {
// Note: It is possible for a normal program to reach here under race condition.
// For example, there could be a race between ClientConn.Close() info being propagated
// to addrConn and http2Client. ClientConn.Close() cancel the context and result
// in http2Client to error. The error info is then caught by transport monitor
// and before addrConn.tearDown() is called in side ClientConn.Close(). Therefore,
// the addrConn will create a new transport. And when registering the new transport in
// channelz, its parent addrConn could have already been torn down and deleted
// from channelz tracking, and thus reach the code here.
logger.Infof("attempt to add child of type %T with id %d to a parent (id=%d) that doesn't currently exist", e, id, d.idNotFound)
}
func (d *dummyEntry) deleteChild(id int64) {
// It is possible for a normal program to reach here under race condition.
// Refer to the example described in addChild().
logger.Infof("attempt to delete child with id %d from a parent (id=%d) that doesn't currently exist", id, d.idNotFound)
}
func (d *dummyEntry) triggerDelete() {
logger.Warningf("attempt to delete an entry (id=%d) that doesn't currently exist", d.idNotFound)
}
func (*dummyEntry) deleteSelfIfReady() {
// code should not reach here. deleteSelfIfReady is always called on an existing entry.
}
func (*dummyEntry) getParentID() int64 {
return 0
}
// ChannelMetric defines the info channelz provides for a specific Channel, which
// includes ChannelInternalMetric and channelz-specific data, such as channelz id,
// child list, etc.
type ChannelMetric struct {
// ID is the channelz id of this channel.
ID int64
// RefName is the human readable reference string of this channel.
RefName string
// ChannelData contains channel internal metric reported by the channel through
// ChannelzMetric().
ChannelData *ChannelInternalMetric
// NestedChans tracks the nested channel type children of this channel in the format of
// a map from nested channel channelz id to corresponding reference string.
NestedChans map[int64]string
// SubChans tracks the subchannel type children of this channel in the format of a
// map from subchannel channelz id to corresponding reference string.
SubChans map[int64]string
// Sockets tracks the socket type children of this channel in the format of a map
// from socket channelz id to corresponding reference string.
// Note current grpc implementation doesn't allow channel having sockets directly,
// therefore, this is field is unused.
Sockets map[int64]string
// Trace contains the most recent traced events.
Trace *ChannelTrace
}
// SubChannelMetric defines the info channelz provides for a specific SubChannel,
// which includes ChannelInternalMetric and channelz-specific data, such as
// channelz id, child list, etc.
type SubChannelMetric struct {
// ID is the channelz id of this subchannel.
ID int64
// RefName is the human readable reference string of this subchannel.
RefName string
// ChannelData contains subchannel internal metric reported by the subchannel
// through ChannelzMetric().
ChannelData *ChannelInternalMetric
// NestedChans tracks the nested channel type children of this subchannel in the format of
// a map from nested channel channelz id to corresponding reference string.
// Note current grpc implementation doesn't allow subchannel to have nested channels
// as children, therefore, this field is unused.
NestedChans map[int64]string
// SubChans tracks the subchannel type children of this subchannel in the format of a
// map from subchannel channelz id to corresponding reference string.
// Note current grpc implementation doesn't allow subchannel to have subchannels
// as children, therefore, this field is unused.
SubChans map[int64]string
// Sockets tracks the socket type children of this subchannel in the format of a map
// from socket channelz id to corresponding reference string.
Sockets map[int64]string
// Trace contains the most recent traced events.
Trace *ChannelTrace
}
// ChannelInternalMetric defines the struct that the implementor of Channel interface
// should return from ChannelzMetric().
type ChannelInternalMetric struct {
// current connectivity state of the channel.
State connectivity.State
// The target this channel originally tried to connect to. May be absent
Target string
// The number of calls started on the channel.
CallsStarted int64
// The number of calls that have completed with an OK status.
CallsSucceeded int64
// The number of calls that have a completed with a non-OK status.
CallsFailed int64
// The last time a call was started on the channel.
LastCallStartedTimestamp time.Time
}
// ChannelTrace stores traced events on a channel/subchannel and related info.
type ChannelTrace struct {
// EventNum is the number of events that ever got traced (i.e. including those that have been deleted)
EventNum int64
// CreationTime is the creation time of the trace.
CreationTime time.Time
// Events stores the most recent trace events (up to $maxTraceEntry, newer event will overwrite the
// oldest one)
Events []*TraceEvent
}
// TraceEvent represent a single trace event
type TraceEvent struct {
// Desc is a simple description of the trace event.
Desc string
// Severity states the severity of this trace event.
Severity Severity
// Timestamp is the event time.
Timestamp time.Time
// RefID is the id of the entity that gets referenced in the event. RefID is 0 if no other entity is
// involved in this event.
// e.g. SubChannel (id: 4[]) Created. --> RefID = 4, RefName = "" (inside [])
RefID int64
// RefName is the reference name for the entity that gets referenced in the event.
RefName string
// RefType indicates the referenced entity type, i.e Channel or SubChannel.
RefType RefChannelType
}
// Channel is the interface that should be satisfied in order to be tracked by
// channelz as Channel or SubChannel.
type Channel interface {
ChannelzMetric() *ChannelInternalMetric
}
type dummyChannel struct{}
func (d *dummyChannel) ChannelzMetric() *ChannelInternalMetric {
return &ChannelInternalMetric{}
}
type channel struct {
refName string
c Channel
closeCalled bool
nestedChans map[int64]string
subChans map[int64]string
id int64
pid int64
cm *channelMap
trace *channelTrace
// traceRefCount is the number of trace events that reference this channel.
// Non-zero traceRefCount means the trace of this channel cannot be deleted.
traceRefCount int32
}
func (c *channel) addChild(id int64, e entry) {
switch v := e.(type) {
case *subChannel:
c.subChans[id] = v.refName
case *channel:
c.nestedChans[id] = v.refName
default:
logger.Errorf("cannot add a child (id = %d) of type %T to a channel", id, e)
}
}
func (c *channel) deleteChild(id int64) {
delete(c.subChans, id)
delete(c.nestedChans, id)
c.deleteSelfIfReady()
}
func (c *channel) triggerDelete() {
c.closeCalled = true
c.deleteSelfIfReady()
}
func (c *channel) getParentID() int64 {
return c.pid
}
// deleteSelfFromTree tries to delete the channel from the channelz entry relation tree, which means
// deleting the channel reference from its parent's child list.
//
// In order for a channel to be deleted from the tree, it must meet the criteria that, removal of the
// corresponding grpc object has been invoked, and the channel does not have any children left.
//
// The returned boolean value indicates whether the channel has been successfully deleted from tree.
func (c *channel) deleteSelfFromTree() (deleted bool) {
if !c.closeCalled || len(c.subChans)+len(c.nestedChans) != 0 {
return false
}
// not top channel
if c.pid != 0 {
c.cm.findEntry(c.pid).deleteChild(c.id)
}
return true
}
// deleteSelfFromMap checks whether it is valid to delete the channel from the map, which means
// deleting the channel from channelz's tracking entirely. Users can no longer use id to query the
// channel, and its memory will be garbage collected.
//
// The trace reference count of the channel must be 0 in order to be deleted from the map. This is
// specified in the channel tracing gRFC that as long as some other trace has reference to an entity,
// the trace of the referenced entity must not be deleted. In order to release the resource allocated
// by grpc, the reference to the grpc object is reset to a dummy object.
//
// deleteSelfFromMap must be called after deleteSelfFromTree returns true.
//
// It returns a bool to indicate whether the channel can be safely deleted from map.
func (c *channel) deleteSelfFromMap() (delete bool) {
if c.getTraceRefCount() != 0 {
c.c = &dummyChannel{}
return false
}
return true
}
// deleteSelfIfReady tries to delete the channel itself from the channelz database.
// The delete process includes two steps:
// 1. delete the channel from the entry relation tree, i.e. delete the channel reference from its
// parent's child list.
// 2. delete the channel from the map, i.e. delete the channel entirely from channelz. Lookup by id
// will return entry not found error.
func (c *channel) deleteSelfIfReady() {
if !c.deleteSelfFromTree() {
return
}
if !c.deleteSelfFromMap() {
return
}
c.cm.deleteEntry(c.id)
c.trace.clear()
}
func (c *channel) getChannelTrace() *channelTrace {
return c.trace
}
func (c *channel) incrTraceRefCount() {
atomic.AddInt32(&c.traceRefCount, 1)
}
func (c *channel) decrTraceRefCount() {
atomic.AddInt32(&c.traceRefCount, -1)
}
func (c *channel) getTraceRefCount() int {
i := atomic.LoadInt32(&c.traceRefCount)
return int(i)
}
func (c *channel) getRefName() string {
return c.refName
}
type subChannel struct {
refName string
c Channel
closeCalled bool
sockets map[int64]string
id int64
pid int64
cm *channelMap
trace *channelTrace
traceRefCount int32
}
func (sc *subChannel) addChild(id int64, e entry) {
if v, ok := e.(*normalSocket); ok {
sc.sockets[id] = v.refName
} else {
logger.Errorf("cannot add a child (id = %d) of type %T to a subChannel", id, e)
}
}
func (sc *subChannel) deleteChild(id int64) {
delete(sc.sockets, id)
sc.deleteSelfIfReady()
}
func (sc *subChannel) triggerDelete() {
sc.closeCalled = true
sc.deleteSelfIfReady()
}
func (sc *subChannel) getParentID() int64 {
return sc.pid
}
// deleteSelfFromTree tries to delete the subchannel from the channelz entry relation tree, which
// means deleting the subchannel reference from its parent's child list.
//
// In order for a subchannel to be deleted from the tree, it must meet the criteria that, removal of
// the corresponding grpc object has been invoked, and the subchannel does not have any children left.
//
// The returned boolean value indicates whether the channel has been successfully deleted from tree.
func (sc *subChannel) deleteSelfFromTree() (deleted bool) {
if !sc.closeCalled || len(sc.sockets) != 0 {
return false
}
sc.cm.findEntry(sc.pid).deleteChild(sc.id)
return true
}
// deleteSelfFromMap checks whether it is valid to delete the subchannel from the map, which means
// deleting the subchannel from channelz's tracking entirely. Users can no longer use id to query
// the subchannel, and its memory will be garbage collected.
//
// The trace reference count of the subchannel must be 0 in order to be deleted from the map. This is
// specified in the channel tracing gRFC that as long as some other trace has reference to an entity,
// the trace of the referenced entity must not be deleted. In order to release the resource allocated
// by grpc, the reference to the grpc object is reset to a dummy object.
//
// deleteSelfFromMap must be called after deleteSelfFromTree returns true.
//
// It returns a bool to indicate whether the channel can be safely deleted from map.
func (sc *subChannel) deleteSelfFromMap() (delete bool) {
if sc.getTraceRefCount() != 0 {
// free the grpc struct (i.e. addrConn)
sc.c = &dummyChannel{}
return false
}
return true
}
// deleteSelfIfReady tries to delete the subchannel itself from the channelz database.
// The delete process includes two steps:
// 1. delete the subchannel from the entry relation tree, i.e. delete the subchannel reference from
// its parent's child list.
// 2. delete the subchannel from the map, i.e. delete the subchannel entirely from channelz. Lookup
// by id will return entry not found error.
func (sc *subChannel) deleteSelfIfReady() {
if !sc.deleteSelfFromTree() {
return
}
if !sc.deleteSelfFromMap() {
return
}
sc.cm.deleteEntry(sc.id)
sc.trace.clear()
}
func (sc *subChannel) getChannelTrace() *channelTrace {
return sc.trace
}
func (sc *subChannel) incrTraceRefCount() {
atomic.AddInt32(&sc.traceRefCount, 1)
}
func (sc *subChannel) decrTraceRefCount() {
atomic.AddInt32(&sc.traceRefCount, -1)
}
func (sc *subChannel) getTraceRefCount() int {
i := atomic.LoadInt32(&sc.traceRefCount)
return int(i)
}
func (sc *subChannel) getRefName() string {
return sc.refName
}
// SocketMetric defines the info channelz provides for a specific Socket, which
// includes SocketInternalMetric and channelz-specific data, such as channelz id, etc.
type SocketMetric struct {
// ID is the channelz id of this socket.
ID int64
// RefName is the human readable reference string of this socket.
RefName string
// SocketData contains socket internal metric reported by the socket through
// ChannelzMetric().
SocketData *SocketInternalMetric
}
// SocketInternalMetric defines the struct that the implementor of Socket interface
// should return from ChannelzMetric().
type SocketInternalMetric struct {
// The number of streams that have been started.
StreamsStarted int64
// The number of streams that have ended successfully:
// On client side, receiving frame with eos bit set.
// On server side, sending frame with eos bit set.
StreamsSucceeded int64
// The number of streams that have ended unsuccessfully:
// On client side, termination without receiving frame with eos bit set.
// On server side, termination without sending frame with eos bit set.
StreamsFailed int64
// The number of messages successfully sent on this socket.
MessagesSent int64
MessagesReceived int64
// The number of keep alives sent. This is typically implemented with HTTP/2
// ping messages.
KeepAlivesSent int64
// The last time a stream was created by this endpoint. Usually unset for
// servers.
LastLocalStreamCreatedTimestamp time.Time
// The last time a stream was created by the remote endpoint. Usually unset
// for clients.
LastRemoteStreamCreatedTimestamp time.Time
// The last time a message was sent by this endpoint.
LastMessageSentTimestamp time.Time
// The last time a message was received by this endpoint.
LastMessageReceivedTimestamp time.Time
// The amount of window, granted to the local endpoint by the remote endpoint.
// This may be slightly out of date due to network latency. This does NOT
// include stream level or TCP level flow control info.
LocalFlowControlWindow int64
// The amount of window, granted to the remote endpoint by the local endpoint.
// This may be slightly out of date due to network latency. This does NOT
// include stream level or TCP level flow control info.
RemoteFlowControlWindow int64
// The locally bound address.
LocalAddr net.Addr
// The remote bound address. May be absent.
RemoteAddr net.Addr
// Optional, represents the name of the remote endpoint, if different than
// the original target name.
RemoteName string
SocketOptions *SocketOptionData
Security credentials.ChannelzSecurityValue
}
// Socket is the interface that should be satisfied in order to be tracked by
// channelz as Socket.
type Socket interface {
ChannelzMetric() *SocketInternalMetric
}
type listenSocket struct {
refName string
s Socket
id int64
pid int64
cm *channelMap
}
func (ls *listenSocket) addChild(id int64, e entry) {
logger.Errorf("cannot add a child (id = %d) of type %T to a listen socket", id, e)
}
func (ls *listenSocket) deleteChild(id int64) {
logger.Errorf("cannot delete a child (id = %d) from a listen socket", id)
}
func (ls *listenSocket) triggerDelete() {
ls.cm.deleteEntry(ls.id)
ls.cm.findEntry(ls.pid).deleteChild(ls.id)
}
func (ls *listenSocket) deleteSelfIfReady() {
logger.Errorf("cannot call deleteSelfIfReady on a listen socket")
}
func (ls *listenSocket) getParentID() int64 {
return ls.pid
}
type normalSocket struct {
refName string
s Socket
id int64
pid int64
cm *channelMap
}
func (ns *normalSocket) addChild(id int64, e entry) {
logger.Errorf("cannot add a child (id = %d) of type %T to a normal socket", id, e)
}
func (ns *normalSocket) deleteChild(id int64) {
logger.Errorf("cannot delete a child (id = %d) from a normal socket", id)
}
func (ns *normalSocket) triggerDelete() {
ns.cm.deleteEntry(ns.id)
ns.cm.findEntry(ns.pid).deleteChild(ns.id)
}
func (ns *normalSocket) deleteSelfIfReady() {
logger.Errorf("cannot call deleteSelfIfReady on a normal socket")
}
func (ns *normalSocket) getParentID() int64 {
return ns.pid
}
// ServerMetric defines the info channelz provides for a specific Server, which
// includes ServerInternalMetric and channelz-specific data, such as channelz id,
// child list, etc.
type ServerMetric struct {
// ID is the channelz id of this server.
ID int64
// RefName is the human readable reference string of this server.
RefName string
// ServerData contains server internal metric reported by the server through
// ChannelzMetric().
ServerData *ServerInternalMetric
// ListenSockets tracks the listener socket type children of this server in the
// format of a map from socket channelz id to corresponding reference string.
ListenSockets map[int64]string
}
// ServerInternalMetric defines the struct that the implementor of Server interface
// should return from ChannelzMetric().
type ServerInternalMetric struct {
// The number of incoming calls started on the server.
CallsStarted int64
// The number of incoming calls that have completed with an OK status.
CallsSucceeded int64
// The number of incoming calls that have a completed with a non-OK status.
CallsFailed int64
// The last time a call was started on the server.
LastCallStartedTimestamp time.Time
}
// Server is the interface to be satisfied in order to be tracked by channelz as
// Server.
type Server interface {
ChannelzMetric() *ServerInternalMetric
}
type server struct {
refName string
s Server
closeCalled bool
sockets map[int64]string
listenSockets map[int64]string
id int64
cm *channelMap
}
func (s *server) addChild(id int64, e entry) {
switch v := e.(type) {
case *normalSocket:
s.sockets[id] = v.refName
case *listenSocket:
s.listenSockets[id] = v.refName
default:
logger.Errorf("cannot add a child (id = %d) of type %T to a server", id, e)
}
}
func (s *server) deleteChild(id int64) {
delete(s.sockets, id)
delete(s.listenSockets, id)
s.deleteSelfIfReady()
}
func (s *server) triggerDelete() {
s.closeCalled = true
s.deleteSelfIfReady()
}
func (s *server) deleteSelfIfReady() {
if !s.closeCalled || len(s.sockets)+len(s.listenSockets) != 0 {
return
}
s.cm.deleteEntry(s.id)
}
func (s *server) getParentID() int64 {
return 0
}
type tracedChannel interface {
getChannelTrace() *channelTrace
incrTraceRefCount()
decrTraceRefCount()
getRefName() string
}
type channelTrace struct {
cm *channelMap
clearCalled bool
createdTime time.Time
eventCount int64
mu sync.Mutex
events []*TraceEvent
}
func (c *channelTrace) append(e *TraceEvent) {
c.mu.Lock()
if len(c.events) == getMaxTraceEntry() {
del := c.events[0]
c.events = c.events[1:]
if del.RefID != 0 {
// start recursive cleanup in a goroutine to not block the call originated from grpc.
go func() {
// need to acquire c.cm.mu lock to call the unlocked attemptCleanup func.
c.cm.mu.Lock()
c.cm.decrTraceRefCount(del.RefID)
c.cm.mu.Unlock()
}()
}
}
e.Timestamp = time.Now()
c.events = append(c.events, e)
c.eventCount++
c.mu.Unlock()
}
func (c *channelTrace) clear() {
if c.clearCalled {
return
}
c.clearCalled = true
c.mu.Lock()
for _, e := range c.events {
if e.RefID != 0 {
// caller should have already held the c.cm.mu lock.
c.cm.decrTraceRefCount(e.RefID)
}
}
c.mu.Unlock()
}
// Severity is the severity level of a trace event.
// The canonical enumeration of all valid values is here:
// https://github.com/grpc/grpc-proto/blob/9b13d199cc0d4703c7ea26c9c330ba695866eb23/grpc/channelz/v1/channelz.proto#L126.
type Severity int
const (
// CtUnknown indicates unknown severity of a trace event.
CtUnknown Severity = iota
// CtInfo indicates info level severity of a trace event.
CtInfo
// CtWarning indicates warning level severity of a trace event.
CtWarning
// CtError indicates error level severity of a trace event.
CtError
)
// RefChannelType is the type of the entity being referenced in a trace event.
type RefChannelType int
const (
// RefUnknown indicates an unknown entity type, the zero value for this type.
RefUnknown RefChannelType = iota
// RefChannel indicates the referenced entity is a Channel.
RefChannel
// RefSubChannel indicates the referenced entity is a SubChannel.
RefSubChannel
// RefServer indicates the referenced entity is a Server.
RefServer
// RefListenSocket indicates the referenced entity is a ListenSocket.
RefListenSocket
// RefNormalSocket indicates the referenced entity is a NormalSocket.
RefNormalSocket
)
var refChannelTypeToString = map[RefChannelType]string{
RefUnknown: "Unknown",
RefChannel: "Channel",
RefSubChannel: "SubChannel",
RefServer: "Server",
RefListenSocket: "ListenSocket",
RefNormalSocket: "NormalSocket",
}
func (r RefChannelType) String() string {
return refChannelTypeToString[r]
}
func (c *channelTrace) dumpData() *ChannelTrace {
c.mu.Lock()
ct := &ChannelTrace{EventNum: c.eventCount, CreationTime: c.createdTime}
ct.Events = c.events[:len(c.events)]
c.mu.Unlock()
return ct
}

View File

@@ -1,37 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"syscall"
)
// GetSocketOption gets the socket option info of the conn.
func GetSocketOption(socket any) *SocketOptionData {
c, ok := socket.(syscall.Conn)
if !ok {
return nil
}
data := &SocketOptionData{}
if rawConn, err := c.SyscallConn(); err == nil {
rawConn.Control(data.Getsockopt)
return data
}
return nil
}

View File

@@ -1,27 +0,0 @@
//go:build !linux
// +build !linux
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
// GetSocketOption gets the socket option info of the conn.
func GetSocketOption(c any) *SocketOptionData {
return nil
}

View File

@@ -28,9 +28,6 @@ import (
var (
// TXTErrIgnore is set if TXT errors should be ignored ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false").
TXTErrIgnore = boolFromEnv("GRPC_GO_IGNORE_TXT_ERRORS", true)
// AdvertiseCompressors is set if registered compressor should be advertised
// ("GRPC_GO_ADVERTISE_COMPRESSORS" is not "false").
AdvertiseCompressors = boolFromEnv("GRPC_GO_ADVERTISE_COMPRESSORS", true)
// RingHashCap indicates the maximum ring size which defaults to 4096
// entries but may be overridden by setting the environment variable
// "GRPC_RING_HASH_CAP". This does not override the default bounds
@@ -43,6 +40,12 @@ var (
// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
// handshakes that can be performed.
ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
// EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
// should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
// option is present for backward compatibility. This option may be overridden
// by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
// or "false".
EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", false)
)
func boolFromEnv(envVar string, def bool) bool {

View File

@@ -1,100 +0,0 @@
//go:build !go1.21
// TODO: when this file is deleted (after Go 1.20 support is dropped), delete
// all of grpcrand and call the rand package directly.
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package grpcrand implements math/rand functions in a concurrent-safe way
// with a global random source, independent of math/rand's global source.
package grpcrand
import (
"math/rand"
"sync"
"time"
)
var (
r = rand.New(rand.NewSource(time.Now().UnixNano()))
mu sync.Mutex
)
// Int implements rand.Int on the grpcrand global source.
func Int() int {
mu.Lock()
defer mu.Unlock()
return r.Int()
}
// Int63n implements rand.Int63n on the grpcrand global source.
func Int63n(n int64) int64 {
mu.Lock()
defer mu.Unlock()
return r.Int63n(n)
}
// Intn implements rand.Intn on the grpcrand global source.
func Intn(n int) int {
mu.Lock()
defer mu.Unlock()
return r.Intn(n)
}
// Int31n implements rand.Int31n on the grpcrand global source.
func Int31n(n int32) int32 {
mu.Lock()
defer mu.Unlock()
return r.Int31n(n)
}
// Float64 implements rand.Float64 on the grpcrand global source.
func Float64() float64 {
mu.Lock()
defer mu.Unlock()
return r.Float64()
}
// Uint64 implements rand.Uint64 on the grpcrand global source.
func Uint64() uint64 {
mu.Lock()
defer mu.Unlock()
return r.Uint64()
}
// Uint32 implements rand.Uint32 on the grpcrand global source.
func Uint32() uint32 {
mu.Lock()
defer mu.Unlock()
return r.Uint32()
}
// ExpFloat64 implements rand.ExpFloat64 on the grpcrand global source.
func ExpFloat64() float64 {
mu.Lock()
defer mu.Unlock()
return r.ExpFloat64()
}
// Shuffle implements rand.Shuffle on the grpcrand global source.
var Shuffle = func(n int, f func(int, int)) {
mu.Lock()
defer mu.Unlock()
r.Shuffle(n, f)
}

View File

@@ -1,73 +0,0 @@
//go:build go1.21
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package grpcrand implements math/rand functions in a concurrent-safe way
// with a global random source, independent of math/rand's global source.
package grpcrand
import "math/rand"
// This implementation will be used for Go version 1.21 or newer.
// For older versions, the original implementation with mutex will be used.
// Int implements rand.Int on the grpcrand global source.
func Int() int {
return rand.Int()
}
// Int63n implements rand.Int63n on the grpcrand global source.
func Int63n(n int64) int64 {
return rand.Int63n(n)
}
// Intn implements rand.Intn on the grpcrand global source.
func Intn(n int) int {
return rand.Intn(n)
}
// Int31n implements rand.Int31n on the grpcrand global source.
func Int31n(n int32) int32 {
return rand.Int31n(n)
}
// Float64 implements rand.Float64 on the grpcrand global source.
func Float64() float64 {
return rand.Float64()
}
// Uint64 implements rand.Uint64 on the grpcrand global source.
func Uint64() uint64 {
return rand.Uint64()
}
// Uint32 implements rand.Uint32 on the grpcrand global source.
func Uint32() uint32 {
return rand.Uint32()
}
// ExpFloat64 implements rand.ExpFloat64 on the grpcrand global source.
func ExpFloat64() float64 {
return rand.ExpFloat64()
}
// Shuffle implements rand.Shuffle on the grpcrand global source.
var Shuffle = func(n int, f func(int, int)) {
rand.Shuffle(n, f)
}

View File

@@ -20,8 +20,6 @@ package grpcutil
import (
"strings"
"google.golang.org/grpc/internal/envconfig"
)
// RegisteredCompressorNames holds names of the registered compressors.
@@ -40,8 +38,5 @@ func IsCompressorNameRegistered(name string) bool {
// RegisteredCompressors returns a string of registered compressor names
// separated by comma.
func RegisteredCompressors() string {
if !envconfig.AdvertiseCompressors {
return ""
}
return strings.Join(RegisteredCompressorNames, ",")
}

View File

@@ -106,6 +106,14 @@ var (
// This is used in the 1.0 release of gcp/observability, and thus must not be
// deleted or changed.
ClearGlobalDialOptions func()
// AddGlobalPerTargetDialOptions adds a PerTargetDialOption that will be
// configured for newly created ClientConns.
AddGlobalPerTargetDialOptions any // func (opt any)
// ClearGlobalPerTargetDialOptions clears the slice of global late apply
// dial options.
ClearGlobalPerTargetDialOptions func()
// JoinDialOptions combines the dial options passed as arguments into a
// single dial option.
JoinDialOptions any // func(...grpc.DialOption) grpc.DialOption
@@ -126,7 +134,8 @@ var (
// deleted or changed.
BinaryLogger any // func(binarylog.Logger) grpc.ServerOption
// SubscribeToConnectivityStateChanges adds a grpcsync.Subscriber to a provided grpc.ClientConn
// SubscribeToConnectivityStateChanges adds a grpcsync.Subscriber to a
// provided grpc.ClientConn.
SubscribeToConnectivityStateChanges any // func(*grpc.ClientConn, grpcsync.Subscriber)
// NewXDSResolverWithConfigForTesting creates a new xds resolver builder using
@@ -184,21 +193,25 @@ var (
ChannelzTurnOffForTesting func()
// TriggerXDSResourceNameNotFoundForTesting triggers the resource-not-found
// error for a given resource type and name. This is usually triggered when
// the associated watch timer fires. For testing purposes, having this
// function makes events more predictable than relying on timer events.
TriggerXDSResourceNameNotFoundForTesting any // func(func(xdsresource.Type, string), string, string) error
// TriggerXDSResourceNotFoundForTesting causes the provided xDS Client to
// invoke resource-not-found error for the given resource type and name.
TriggerXDSResourceNotFoundForTesting any // func(xdsclient.XDSClient, xdsresource.Type, string) error
// TriggerXDSResourceNotFoundClient invokes the testing xDS Client singleton
// to invoke resource not found for a resource type name and resource name.
TriggerXDSResourceNameNotFoundClient any // func(string, string) error
// FromOutgoingContextRaw returns the un-merged, intermediary contents of metadata.rawMD.
// FromOutgoingContextRaw returns the un-merged, intermediary contents of
// metadata.rawMD.
FromOutgoingContextRaw any // func(context.Context) (metadata.MD, [][]string, bool)
// UserSetDefaultScheme is set to true if the user has overridden the
// default resolver scheme.
UserSetDefaultScheme bool = false
// ShuffleAddressListForTesting pseudo-randomizes the order of addresses. n
// is the number of elements. swap swaps the elements with indexes i and j.
ShuffleAddressListForTesting any // func(n int, swap func(i, j int))
)
// HealthChecker defines the signature of the client-side LB channel health checking function.
// HealthChecker defines the signature of the client-side LB channel health
// checking function.
//
// The implementation is expected to create a health checking RPC stream by
// calling newStream(), watch for the health status of serviceName, and report

View File

@@ -24,9 +24,8 @@ import (
"encoding/json"
"fmt"
protov1 "github.com/golang/protobuf/proto"
"google.golang.org/protobuf/encoding/protojson"
protov2 "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/protoadapt"
)
const jsonIndent = " "
@@ -35,21 +34,14 @@ const jsonIndent = " "
//
// If marshal fails, it falls back to fmt.Sprintf("%+v").
func ToJSON(e any) string {
switch ee := e.(type) {
case protov1.Message:
mm := protojson.MarshalOptions{Indent: jsonIndent}
ret, err := mm.Marshal(protov1.MessageV2(ee))
if err != nil {
// This may fail for proto.Anys, e.g. for xDS v2, LDS, the v2
// messages are not imported, and this will fail because the message
// is not found.
return fmt.Sprintf("%+v", ee)
}
return string(ret)
case protov2.Message:
if ee, ok := e.(protoadapt.MessageV1); ok {
e = protoadapt.MessageV2Of(ee)
}
if ee, ok := e.(protoadapt.MessageV2); ok {
mm := protojson.MarshalOptions{
Multiline: true,
Indent: jsonIndent,
Multiline: true,
}
ret, err := mm.Marshal(ee)
if err != nil {
@@ -59,13 +51,13 @@ func ToJSON(e any) string {
return fmt.Sprintf("%+v", ee)
}
return string(ret)
default:
ret, err := json.MarshalIndent(ee, "", jsonIndent)
if err != nil {
return fmt.Sprintf("%+v", ee)
}
return string(ret)
}
ret, err := json.MarshalIndent(e, "", jsonIndent)
if err != nil {
return fmt.Sprintf("%+v", e)
}
return string(ret)
}
// FormatJSON formats the input json bytes with indentation.

View File

@@ -24,6 +24,7 @@ import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net"
"os"
"strconv"
@@ -35,21 +36,35 @@ import (
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/resolver/dns/internal"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
// EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB
// addresses from SRV records. Must not be changed after init time.
var EnableSRVLookups = false
var (
// EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB
// addresses from SRV records. Must not be changed after init time.
EnableSRVLookups = false
var logger = grpclog.Component("dns")
// MinResolutionInterval is the minimum interval at which re-resolutions are
// allowed. This helps to prevent excessive re-resolution.
MinResolutionInterval = 30 * time.Second
// ResolvingTimeout specifies the maximum duration for a DNS resolution request.
// If the timeout expires before a response is received, the request will be canceled.
//
// It is recommended to set this value at application startup. Avoid modifying this variable
// after initialization as it's not thread-safe for concurrent modification.
ResolvingTimeout = 30 * time.Second
logger = grpclog.Component("dns")
)
func init() {
resolver.Register(NewBuilder())
internal.TimeAfterFunc = time.After
internal.TimeNowFunc = time.Now
internal.TimeUntilFunc = time.Until
internal.NewNetResolver = newNetResolver
internal.AddressDialer = addressDialer
}
@@ -196,12 +211,12 @@ func (d *dnsResolver) watcher() {
err = d.cc.UpdateState(*state)
}
var waitTime time.Duration
var nextResolutionTime time.Time
if err == nil {
// Success resolving, wait for the next ResolveNow. However, also wait 30
// seconds at the very least to prevent constantly re-resolving.
backoffIndex = 1
waitTime = internal.MinResolutionRate
nextResolutionTime = internal.TimeNowFunc().Add(MinResolutionInterval)
select {
case <-d.ctx.Done():
return
@@ -210,29 +225,29 @@ func (d *dnsResolver) watcher() {
} else {
// Poll on an error found in DNS Resolver or an error received from
// ClientConn.
waitTime = backoff.DefaultExponential.Backoff(backoffIndex)
nextResolutionTime = internal.TimeNowFunc().Add(backoff.DefaultExponential.Backoff(backoffIndex))
backoffIndex++
}
select {
case <-d.ctx.Done():
return
case <-internal.TimeAfterFunc(waitTime):
case <-internal.TimeAfterFunc(internal.TimeUntilFunc(nextResolutionTime)):
}
}
}
func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) {
func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error) {
if !EnableSRVLookups {
return nil, nil
}
var newAddrs []resolver.Address
_, srvs, err := d.resolver.LookupSRV(d.ctx, "grpclb", "tcp", d.host)
_, srvs, err := d.resolver.LookupSRV(ctx, "grpclb", "tcp", d.host)
if err != nil {
err = handleDNSError(err, "SRV") // may become nil
return nil, err
}
for _, s := range srvs {
lbAddrs, err := d.resolver.LookupHost(d.ctx, s.Target)
lbAddrs, err := d.resolver.LookupHost(ctx, s.Target)
if err != nil {
err = handleDNSError(err, "A") // may become nil
if err == nil {
@@ -269,8 +284,8 @@ func handleDNSError(err error, lookupType string) error {
return err
}
func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult {
ss, err := d.resolver.LookupTXT(d.ctx, txtPrefix+d.host)
func (d *dnsResolver) lookupTXT(ctx context.Context) *serviceconfig.ParseResult {
ss, err := d.resolver.LookupTXT(ctx, txtPrefix+d.host)
if err != nil {
if envconfig.TXTErrIgnore {
return nil
@@ -297,8 +312,8 @@ func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult {
return d.cc.ParseServiceConfig(sc)
}
func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
addrs, err := d.resolver.LookupHost(d.ctx, d.host)
func (d *dnsResolver) lookupHost(ctx context.Context) ([]resolver.Address, error) {
addrs, err := d.resolver.LookupHost(ctx, d.host)
if err != nil {
err = handleDNSError(err, "A")
return nil, err
@@ -316,8 +331,10 @@ func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
}
func (d *dnsResolver) lookup() (*resolver.State, error) {
srv, srvErr := d.lookupSRV()
addrs, hostErr := d.lookupHost()
ctx, cancel := context.WithTimeout(d.ctx, ResolvingTimeout)
defer cancel()
srv, srvErr := d.lookupSRV(ctx)
addrs, hostErr := d.lookupHost(ctx)
if hostErr != nil && (srvErr != nil || len(srv) == 0) {
return nil, hostErr
}
@@ -327,7 +344,7 @@ func (d *dnsResolver) lookup() (*resolver.State, error) {
state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
}
if !d.disableServiceConfig {
state.ServiceConfig = d.lookupTXT()
state.ServiceConfig = d.lookupTXT(ctx)
}
return &state, nil
}
@@ -408,7 +425,7 @@ func chosenByPercentage(a *int) bool {
if a == nil {
return true
}
return grpcrand.Intn(100)+1 <= *a
return rand.Intn(100)+1 <= *a
}
func canaryingSC(js string) string {

View File

@@ -28,7 +28,7 @@ import (
// NetResolver groups the methods on net.Resolver that are used by the DNS
// resolver implementation. This allows the default net.Resolver instance to be
// overidden from tests.
// overridden from tests.
type NetResolver interface {
LookupHost(ctx context.Context, host string) (addrs []string, err error)
LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
@@ -50,16 +50,23 @@ var (
// The following vars are overridden from tests.
var (
// MinResolutionRate is the minimum rate at which re-resolutions are
// allowed. This helps to prevent excessive re-resolution.
MinResolutionRate = 30 * time.Second
// TimeAfterFunc is used by the DNS resolver to wait for the given duration
// to elapse. In non-test code, this is implemented by time.After. In test
// to elapse. In non-test code, this is implemented by time.After. In test
// code, this can be used to control the amount of time the resolver is
// blocked waiting for the duration to elapse.
TimeAfterFunc func(time.Duration) <-chan time.Time
// TimeNowFunc is used by the DNS resolver to get the current time.
// In non-test code, this is implemented by time.Now. In test code,
// this can be used to control the current time for the resolver.
TimeNowFunc func() time.Time
// TimeUntilFunc is used by the DNS resolver to calculate the remaining
// wait time for re-resolution. In non-test code, this is implemented by
// time.Until. In test code, this can be used to control the remaining
// time for resolver to wait for re-resolution.
TimeUntilFunc func(time.Time) time.Duration
// NewNetResolver returns the net.Resolver instance for the given target.
NewNetResolver func(string) (NetResolver, error)

View File

@@ -193,7 +193,7 @@ type goAway struct {
code http2.ErrCode
debugData []byte
headsUp bool
closeConn error // if set, loopyWriter will exit, resulting in conn closure
closeConn error // if set, loopyWriter will exit with this error
}
func (*goAway) isTransportResponseFrame() bool { return false }
@@ -336,7 +336,7 @@ func (c *controlBuffer) put(it cbItem) error {
return err
}
func (c *controlBuffer) executeAndPut(f func(it any) bool, it cbItem) (bool, error) {
func (c *controlBuffer) executeAndPut(f func() bool, it cbItem) (bool, error) {
var wakeUp bool
c.mu.Lock()
if c.err != nil {
@@ -344,7 +344,7 @@ func (c *controlBuffer) executeAndPut(f func(it any) bool, it cbItem) (bool, err
return false, c.err
}
if f != nil {
if !f(it) { // f wasn't successful
if !f() { // f wasn't successful
c.mu.Unlock()
return false, nil
}
@@ -495,21 +495,22 @@ type loopyWriter struct {
ssGoAwayHandler func(*goAway) (bool, error)
}
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger) *loopyWriter {
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error)) *loopyWriter {
var buf bytes.Buffer
l := &loopyWriter{
side: s,
cbuf: cbuf,
sendQuota: defaultWindowSize,
oiws: defaultWindowSize,
estdStreams: make(map[uint32]*outStream),
activeStreams: newOutStreamList(),
framer: fr,
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
bdpEst: bdpEst,
conn: conn,
logger: logger,
side: s,
cbuf: cbuf,
sendQuota: defaultWindowSize,
oiws: defaultWindowSize,
estdStreams: make(map[uint32]*outStream),
activeStreams: newOutStreamList(),
framer: fr,
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
bdpEst: bdpEst,
conn: conn,
logger: logger,
ssGoAwayHandler: goAwayHandler,
}
return l
}

View File

@@ -51,14 +51,10 @@ import (
// inside an http.Handler, or writes an HTTP error to w and returns an error.
// It requires that the http Server supports HTTP/2.
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler) (ServerTransport, error) {
if r.ProtoMajor != 2 {
msg := "gRPC requires HTTP/2"
http.Error(w, msg, http.StatusBadRequest)
return nil, errors.New(msg)
}
if r.Method != "POST" {
if r.Method != http.MethodPost {
w.Header().Set("Allow", http.MethodPost)
msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
http.Error(w, msg, http.StatusBadRequest)
http.Error(w, msg, http.StatusMethodNotAllowed)
return nil, errors.New(msg)
}
contentType := r.Header.Get("Content-Type")
@@ -69,6 +65,11 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
http.Error(w, msg, http.StatusUnsupportedMediaType)
return nil, errors.New(msg)
}
if r.ProtoMajor != 2 {
msg := "gRPC requires HTTP/2"
http.Error(w, msg, http.StatusHTTPVersionNotSupported)
return nil, errors.New(msg)
}
if _, ok := w.(http.Flusher); !ok {
msg := "gRPC requires a ResponseWriter supporting http.Flusher"
http.Error(w, msg, http.StatusInternalServerError)

View File

@@ -114,11 +114,11 @@ type http2Client struct {
streamQuota int64
streamsQuotaAvailable chan struct{}
waitingStreams uint32
nextID uint32
registeredCompressors string
// Do not access controlBuf with mu held.
mu sync.Mutex // guard the following variables
nextID uint32
state transportState
activeStreams map[uint32]*Stream
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
@@ -140,9 +140,7 @@ type http2Client struct {
// variable.
kpDormant bool
// Fields below are for channelz metric collection.
channelzID *channelz.Identifier
czData *channelzData
channelz *channelz.Socket
onClose func(GoAwayReason)
@@ -319,6 +317,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
if opts.MaxHeaderListSize != nil {
maxHeaderListSize = *opts.MaxHeaderListSize
}
t := &http2Client{
ctx: ctx,
ctxDone: ctx.Done(), // Cache Done chan.
@@ -346,11 +345,25 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
maxConcurrentStreams: defaultMaxStreamsClient,
streamQuota: defaultMaxStreamsClient,
streamsQuotaAvailable: make(chan struct{}, 1),
czData: new(channelzData),
keepaliveEnabled: keepaliveEnabled,
bufferPool: newBufferPool(),
onClose: onClose,
}
var czSecurity credentials.ChannelzSecurityValue
if au, ok := authInfo.(credentials.ChannelzSecurityInfo); ok {
czSecurity = au.GetSecurityValue()
}
t.channelz = channelz.RegisterSocket(
&channelz.Socket{
SocketType: channelz.SocketTypeNormal,
Parent: opts.ChannelzParent,
SocketMetrics: channelz.SocketMetrics{},
EphemeralMetrics: t.socketMetrics,
LocalAddr: t.localAddr,
RemoteAddr: t.remoteAddr,
SocketOptions: channelz.GetSocketOption(t.conn),
Security: czSecurity,
})
t.logger = prefixLoggerForClientTransport(t)
// Add peer information to the http2client context.
t.ctx = peer.NewContext(t.ctx, t.getPeer())
@@ -381,10 +394,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
}
sh.HandleConn(t.ctx, connBegin)
}
t.channelzID, err = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, fmt.Sprintf("%s -> %s", t.localAddr, t.remoteAddr))
if err != nil {
return nil, err
}
if t.keepaliveEnabled {
t.kpDormancyCond = sync.NewCond(&t.mu)
go t.keepalive()
@@ -399,10 +408,10 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
readerErrCh := make(chan error, 1)
go t.reader(readerErrCh)
defer func() {
if err == nil {
err = <-readerErrCh
}
if err != nil {
// writerDone should be closed since the loopy goroutine
// wouldn't have started in the case this function returns an error.
close(t.writerDone)
t.Close(err)
}
}()
@@ -449,8 +458,12 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
if err := t.framer.writer.Flush(); err != nil {
return nil, err
}
// Block until the server preface is received successfully or an error occurs.
if err = <-readerErrCh; err != nil {
return nil, err
}
go func() {
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger)
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler)
if err := t.loopy.run(); !isIOError(err) {
// Immediately close the connection, as the loopy writer returns
// when there are no more active streams and we were draining (the
@@ -508,6 +521,17 @@ func (t *http2Client) getPeer() *peer.Peer {
}
}
// OutgoingGoAwayHandler writes a GOAWAY to the connection. Always returns (false, err) as we want the GoAway
// to be the last frame loopy writes to the transport.
func (t *http2Client) outgoingGoAwayHandler(g *goAway) (bool, error) {
t.mu.Lock()
defer t.mu.Unlock()
if err := t.framer.fr.WriteGoAway(t.nextID-2, http2.ErrCodeNo, g.debugData); err != nil {
return false, err
}
return false, g.closeConn
}
func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) {
aud := t.createAudience(callHdr)
ri := credentials.RequestInfo{
@@ -756,8 +780,8 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
return ErrConnClosing
}
if channelz.IsOn() {
atomic.AddInt64(&t.czData.streamsStarted, 1)
atomic.StoreInt64(&t.czData.lastStreamCreatedTime, time.Now().UnixNano())
t.channelz.SocketMetrics.StreamsStarted.Add(1)
t.channelz.SocketMetrics.LastLocalStreamCreatedTimestamp.Store(time.Now().UnixNano())
}
// If the keepalive goroutine has gone dormant, wake it up.
if t.kpDormant {
@@ -772,7 +796,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
firstTry := true
var ch chan struct{}
transportDrainRequired := false
checkForStreamQuota := func(it any) bool {
checkForStreamQuota := func() bool {
if t.streamQuota <= 0 { // Can go negative if server decreases it.
if firstTry {
t.waitingStreams++
@@ -784,23 +808,24 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
t.waitingStreams--
}
t.streamQuota--
h := it.(*headerFrame)
h.streamID = t.nextID
t.nextID += 2
// Drain client transport if nextID > MaxStreamID which signals gRPC that
// the connection is closed and a new one must be created for subsequent RPCs.
transportDrainRequired = t.nextID > MaxStreamID
s.id = h.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
t.mu.Lock()
if t.state == draining || t.activeStreams == nil { // Can be niled from Close().
t.mu.Unlock()
return false // Don't create a stream if the transport is already closed.
}
hdr.streamID = t.nextID
t.nextID += 2
// Drain client transport if nextID > MaxStreamID which signals gRPC that
// the connection is closed and a new one must be created for subsequent RPCs.
transportDrainRequired = t.nextID > MaxStreamID
s.id = hdr.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
t.activeStreams[s.id] = s
t.mu.Unlock()
if t.streamQuota > 0 && t.waitingStreams > 0 {
select {
case t.streamsQuotaAvailable <- struct{}{}:
@@ -810,13 +835,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
return true
}
var hdrListSizeErr error
checkForHeaderListSize := func(it any) bool {
checkForHeaderListSize := func() bool {
if t.maxSendHeaderListSize == nil {
return true
}
hdrFrame := it.(*headerFrame)
var sz int64
for _, f := range hdrFrame.hf {
for _, f := range hdr.hf {
if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) {
hdrListSizeErr = status.Errorf(codes.Internal, "header list size to send violates the maximum size (%d bytes) set by server", *t.maxSendHeaderListSize)
return false
@@ -825,8 +849,8 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
return true
}
for {
success, err := t.controlBuf.executeAndPut(func(it any) bool {
return checkForHeaderListSize(it) && checkForStreamQuota(it)
success, err := t.controlBuf.executeAndPut(func() bool {
return checkForHeaderListSize() && checkForStreamQuota()
}, hdr)
if err != nil {
// Connection closed.
@@ -928,16 +952,16 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
t.mu.Unlock()
if channelz.IsOn() {
if eosReceived {
atomic.AddInt64(&t.czData.streamsSucceeded, 1)
t.channelz.SocketMetrics.StreamsSucceeded.Add(1)
} else {
atomic.AddInt64(&t.czData.streamsFailed, 1)
t.channelz.SocketMetrics.StreamsFailed.Add(1)
}
}
},
rst: rst,
rstCode: rstCode,
}
addBackStreamQuota := func(any) bool {
addBackStreamQuota := func() bool {
t.streamQuota++
if t.streamQuota > 0 && t.waitingStreams > 0 {
select {
@@ -957,7 +981,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
// Close kicks off the shutdown process of the transport. This should be called
// only once on a transport. Once it is called, the transport should not be
// accessed any more.
// accessed anymore.
func (t *http2Client) Close(err error) {
t.mu.Lock()
// Make sure we only close once.
@@ -982,10 +1006,13 @@ func (t *http2Client) Close(err error) {
t.kpDormancyCond.Signal()
}
t.mu.Unlock()
t.controlBuf.finish()
// Per HTTP/2 spec, a GOAWAY frame must be sent before closing the
// connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY.
t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err})
<-t.writerDone
t.cancel()
t.conn.Close()
channelz.RemoveEntry(t.channelzID)
channelz.RemoveEntry(t.channelz.ID)
// Append info about previous goaways if there were any, since this may be important
// for understanding the root cause for this connection to be closed.
_, goAwayDebugMessage := t.GetGoAwayReason()
@@ -1090,7 +1117,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
// for the transport and the stream based on the current bdp
// estimation.
func (t *http2Client) updateFlowControl(n uint32) {
updateIWS := func(any) bool {
updateIWS := func() bool {
t.initialWindowSize = int32(n)
t.mu.Lock()
for _, s := range t.activeStreams {
@@ -1243,7 +1270,7 @@ func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) {
}
updateFuncs = append(updateFuncs, updateStreamQuota)
}
t.controlBuf.executeAndPut(func(any) bool {
t.controlBuf.executeAndPut(func() bool {
for _, f := range updateFuncs {
f()
}
@@ -1708,7 +1735,7 @@ func (t *http2Client) keepalive() {
// keepalive timer expired. In both cases, we need to send a ping.
if !outstandingPing {
if channelz.IsOn() {
atomic.AddInt64(&t.czData.kpCount, 1)
t.channelz.SocketMetrics.KeepAlivesSent.Add(1)
}
t.controlBuf.put(p)
timeoutLeft = t.kp.Timeout
@@ -1738,40 +1765,23 @@ func (t *http2Client) GoAway() <-chan struct{} {
return t.goAway
}
func (t *http2Client) ChannelzMetric() *channelz.SocketInternalMetric {
s := channelz.SocketInternalMetric{
StreamsStarted: atomic.LoadInt64(&t.czData.streamsStarted),
StreamsSucceeded: atomic.LoadInt64(&t.czData.streamsSucceeded),
StreamsFailed: atomic.LoadInt64(&t.czData.streamsFailed),
MessagesSent: atomic.LoadInt64(&t.czData.msgSent),
MessagesReceived: atomic.LoadInt64(&t.czData.msgRecv),
KeepAlivesSent: atomic.LoadInt64(&t.czData.kpCount),
LastLocalStreamCreatedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastStreamCreatedTime)),
LastMessageSentTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgSentTime)),
LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
LocalFlowControlWindow: int64(t.fc.getSize()),
SocketOptions: channelz.GetSocketOption(t.conn),
LocalAddr: t.localAddr,
RemoteAddr: t.remoteAddr,
// RemoteName :
func (t *http2Client) socketMetrics() *channelz.EphemeralSocketMetrics {
return &channelz.EphemeralSocketMetrics{
LocalFlowControlWindow: int64(t.fc.getSize()),
RemoteFlowControlWindow: t.getOutFlowWindow(),
}
if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok {
s.Security = au.GetSecurityValue()
}
s.RemoteFlowControlWindow = t.getOutFlowWindow()
return &s
}
func (t *http2Client) RemoteAddr() net.Addr { return t.remoteAddr }
func (t *http2Client) IncrMsgSent() {
atomic.AddInt64(&t.czData.msgSent, 1)
atomic.StoreInt64(&t.czData.lastMsgSentTime, time.Now().UnixNano())
t.channelz.SocketMetrics.MessagesSent.Add(1)
t.channelz.SocketMetrics.LastMessageSentTimestamp.Store(time.Now().UnixNano())
}
func (t *http2Client) IncrMsgRecv() {
atomic.AddInt64(&t.czData.msgRecv, 1)
atomic.StoreInt64(&t.czData.lastMsgRecvTime, time.Now().UnixNano())
t.channelz.SocketMetrics.MessagesReceived.Add(1)
t.channelz.SocketMetrics.LastMessageReceivedTimestamp.Store(time.Now().UnixNano())
}
func (t *http2Client) getOutFlowWindow() int64 {

View File

@@ -25,6 +25,7 @@ import (
"fmt"
"io"
"math"
"math/rand"
"net"
"net/http"
"strconv"
@@ -43,7 +44,6 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
@@ -118,8 +118,7 @@ type http2Server struct {
idle time.Time
// Fields below are for channelz metric collection.
channelzID *channelz.Identifier
czData *channelzData
channelz *channelz.Socket
bufferPool *bufferPool
connectionID uint64
@@ -262,9 +261,24 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
idle: time.Now(),
kep: kep,
initialWindowSize: iwz,
czData: new(channelzData),
bufferPool: newBufferPool(),
}
var czSecurity credentials.ChannelzSecurityValue
if au, ok := authInfo.(credentials.ChannelzSecurityInfo); ok {
czSecurity = au.GetSecurityValue()
}
t.channelz = channelz.RegisterSocket(
&channelz.Socket{
SocketType: channelz.SocketTypeNormal,
Parent: config.ChannelzParent,
SocketMetrics: channelz.SocketMetrics{},
EphemeralMetrics: t.socketMetrics,
LocalAddr: t.peer.LocalAddr,
RemoteAddr: t.peer.Addr,
SocketOptions: channelz.GetSocketOption(t.conn),
Security: czSecurity,
},
)
t.logger = prefixLoggerForServerTransport(t)
t.controlBuf = newControlBuffer(t.done)
@@ -274,10 +288,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
updateFlowControl: t.updateFlowControl,
}
}
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.peer.Addr, t.peer.LocalAddr))
if err != nil {
return nil, err
}
t.connectionID = atomic.AddUint64(&serverConnectionCounter, 1)
t.framer.writer.Flush()
@@ -320,8 +330,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
t.handleSettings(sf)
go func() {
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger)
t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler)
err := t.loopy.run()
close(t.loopyWriterDone)
if !isIOError(err) {
@@ -334,9 +343,11 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
// closed, would lead to a TCP RST instead of FIN, and the client
// encountering errors. For more info:
// https://github.com/grpc/grpc-go/issues/5358
timer := time.NewTimer(time.Second)
defer timer.Stop()
select {
case <-t.readerDone:
case <-time.After(time.Second):
case <-timer.C:
}
t.conn.Close()
}
@@ -592,8 +603,8 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
}
t.mu.Unlock()
if channelz.IsOn() {
atomic.AddInt64(&t.czData.streamsStarted, 1)
atomic.StoreInt64(&t.czData.lastStreamCreatedTime, time.Now().UnixNano())
t.channelz.SocketMetrics.StreamsStarted.Add(1)
t.channelz.SocketMetrics.LastRemoteStreamCreatedTimestamp.Store(time.Now().UnixNano())
}
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
@@ -658,8 +669,14 @@ func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
switch frame := frame.(type) {
case *http2.MetaHeadersFrame:
if err := t.operateHeaders(ctx, frame, handle); err != nil {
t.Close(err)
break
// Any error processing client headers, e.g. invalid stream ID,
// is considered a protocol violation.
t.controlBuf.put(&goAway{
code: http2.ErrCodeProtocol,
debugData: []byte(err.Error()),
closeConn: err,
})
continue
}
case *http2.DataFrame:
t.handleData(frame)
@@ -842,7 +859,7 @@ func (t *http2Server) handleSettings(f *http2.SettingsFrame) {
}
return nil
})
t.controlBuf.executeAndPut(func(any) bool {
t.controlBuf.executeAndPut(func() bool {
for _, f := range updateFuncs {
f()
}
@@ -996,12 +1013,13 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
}
headerFields = appendHeaderFieldsFromMD(headerFields, s.header)
success, err := t.controlBuf.executeAndPut(t.checkForHeaderListSize, &headerFrame{
hf := &headerFrame{
streamID: s.id,
hf: headerFields,
endStream: false,
onWrite: t.setResetPingStrikes,
})
}
success, err := t.controlBuf.executeAndPut(func() bool { return t.checkForHeaderListSize(hf) }, hf)
if !success {
if err != nil {
return err
@@ -1190,12 +1208,12 @@ func (t *http2Server) keepalive() {
continue
}
if outstandingPing && kpTimeoutLeft <= 0 {
t.Close(fmt.Errorf("keepalive ping not acked within timeout %s", t.kp.Time))
t.Close(fmt.Errorf("keepalive ping not acked within timeout %s", t.kp.Timeout))
return
}
if !outstandingPing {
if channelz.IsOn() {
atomic.AddInt64(&t.czData.kpCount, 1)
t.channelz.SocketMetrics.KeepAlivesSent.Add(1)
}
t.controlBuf.put(p)
kpTimeoutLeft = t.kp.Timeout
@@ -1235,7 +1253,7 @@ func (t *http2Server) Close(err error) {
if err := t.conn.Close(); err != nil && t.logger.V(logLevel) {
t.logger.Infof("Error closing underlying net.Conn during Close: %v", err)
}
channelz.RemoveEntry(t.channelzID)
channelz.RemoveEntry(t.channelz.ID)
// Cancel all active streams.
for _, s := range streams {
s.cancel()
@@ -1256,9 +1274,9 @@ func (t *http2Server) deleteStream(s *Stream, eosReceived bool) {
if channelz.IsOn() {
if eosReceived {
atomic.AddInt64(&t.czData.streamsSucceeded, 1)
t.channelz.SocketMetrics.StreamsSucceeded.Add(1)
} else {
atomic.AddInt64(&t.czData.streamsFailed, 1)
t.channelz.SocketMetrics.StreamsFailed.Add(1)
}
}
}
@@ -1375,38 +1393,21 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
return false, nil
}
func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric {
s := channelz.SocketInternalMetric{
StreamsStarted: atomic.LoadInt64(&t.czData.streamsStarted),
StreamsSucceeded: atomic.LoadInt64(&t.czData.streamsSucceeded),
StreamsFailed: atomic.LoadInt64(&t.czData.streamsFailed),
MessagesSent: atomic.LoadInt64(&t.czData.msgSent),
MessagesReceived: atomic.LoadInt64(&t.czData.msgRecv),
KeepAlivesSent: atomic.LoadInt64(&t.czData.kpCount),
LastRemoteStreamCreatedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastStreamCreatedTime)),
LastMessageSentTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgSentTime)),
LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
LocalFlowControlWindow: int64(t.fc.getSize()),
SocketOptions: channelz.GetSocketOption(t.conn),
LocalAddr: t.peer.LocalAddr,
RemoteAddr: t.peer.Addr,
// RemoteName :
func (t *http2Server) socketMetrics() *channelz.EphemeralSocketMetrics {
return &channelz.EphemeralSocketMetrics{
LocalFlowControlWindow: int64(t.fc.getSize()),
RemoteFlowControlWindow: t.getOutFlowWindow(),
}
if au, ok := t.peer.AuthInfo.(credentials.ChannelzSecurityInfo); ok {
s.Security = au.GetSecurityValue()
}
s.RemoteFlowControlWindow = t.getOutFlowWindow()
return &s
}
func (t *http2Server) IncrMsgSent() {
atomic.AddInt64(&t.czData.msgSent, 1)
atomic.StoreInt64(&t.czData.lastMsgSentTime, time.Now().UnixNano())
t.channelz.SocketMetrics.MessagesSent.Add(1)
t.channelz.SocketMetrics.LastMessageSentTimestamp.Add(1)
}
func (t *http2Server) IncrMsgRecv() {
atomic.AddInt64(&t.czData.msgRecv, 1)
atomic.StoreInt64(&t.czData.lastMsgRecvTime, time.Now().UnixNano())
t.channelz.SocketMetrics.MessagesReceived.Add(1)
t.channelz.SocketMetrics.LastMessageReceivedTimestamp.Add(1)
}
func (t *http2Server) getOutFlowWindow() int64 {
@@ -1439,7 +1440,7 @@ func getJitter(v time.Duration) time.Duration {
}
// Generate a jitter between +/- 10% of the value.
r := int64(v / 10)
j := grpcrand.Int63n(2*r) - r
j := rand.Int63n(2*r) - r
return time.Duration(j)
}

View File

@@ -418,10 +418,9 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBu
return f
}
func getWriteBufferPool(writeBufferSize int) *sync.Pool {
func getWriteBufferPool(size int) *sync.Pool {
writeBufferMutex.Lock()
defer writeBufferMutex.Unlock()
size := writeBufferSize * 2
pool, ok := writeBufferPoolMap[size]
if ok {
return pool

View File

@@ -28,6 +28,7 @@ import (
"fmt"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
@@ -303,7 +304,7 @@ func (s *Stream) isHeaderSent() bool {
}
// updateHeaderSent updates headerSent and returns true
// if it was alreay set. It is valid only on server-side.
// if it was already set. It is valid only on server-side.
func (s *Stream) updateHeaderSent() bool {
return atomic.SwapUint32(&s.headerSent, 1) == 1
}
@@ -362,8 +363,12 @@ func (s *Stream) SendCompress() string {
// ClientAdvertisedCompressors returns the compressor names advertised by the
// client via grpc-accept-encoding header.
func (s *Stream) ClientAdvertisedCompressors() string {
return s.clientAdvertisedCompressors
func (s *Stream) ClientAdvertisedCompressors() []string {
values := strings.Split(s.clientAdvertisedCompressors, ",")
for i, v := range values {
values[i] = strings.TrimSpace(v)
}
return values
}
// Done returns a channel which is closed when it receives the final status
@@ -566,7 +571,7 @@ type ServerConfig struct {
WriteBufferSize int
ReadBufferSize int
SharedWriteBuffer bool
ChannelzParentID *channelz.Identifier
ChannelzParent *channelz.Server
MaxHeaderListSize *uint32
HeaderTableSize *uint32
}
@@ -601,8 +606,8 @@ type ConnectOptions struct {
ReadBufferSize int
// SharedWriteBuffer indicates whether connections should reuse write buffer
SharedWriteBuffer bool
// ChannelzParentID sets the addrConn id which initiate the creation of this client transport.
ChannelzParentID *channelz.Identifier
// ChannelzParent sets the addrConn id which initiated the creation of this client transport.
ChannelzParent *channelz.SubChannel
// MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received.
MaxHeaderListSize *uint32
// UseProxy specifies if a proxy should be used.
@@ -815,30 +820,6 @@ const (
GoAwayTooManyPings GoAwayReason = 2
)
// channelzData is used to store channelz related data for http2Client and http2Server.
// These fields cannot be embedded in the original structs (e.g. http2Client), since to do atomic
// operation on int64 variable on 32-bit machine, user is responsible to enforce memory alignment.
// Here, by grouping those int64 fields inside a struct, we are enforcing the alignment.
type channelzData struct {
kpCount int64
// The number of streams that have started, including already finished ones.
streamsStarted int64
// Client side: The number of streams that have ended successfully by receiving
// EoS bit set frame from server.
// Server side: The number of streams that have ended successfully by sending
// frame with EoS bit set.
streamsSucceeded int64
streamsFailed int64
// lastStreamCreatedTime stores the timestamp that the last stream gets created. It is of int64 type
// instead of time.Time since it's more costly to atomically update time.Time variable than int64
// variable. The same goes for lastMsgSentTime and lastMsgRecvTime.
lastStreamCreatedTime int64
msgSent int64
msgRecv int64
lastMsgSentTime int64
lastMsgRecvTime int64
}
// ContextErr converts the error from context package into a status error.
func ContextErr(err error) error {
switch err {

View File

@@ -1,40 +0,0 @@
/*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/resolver"
)
// handshakeClusterNameKey is the type used as the key to store cluster name in
// the Attributes field of resolver.Address.
type handshakeClusterNameKey struct{}
// SetXDSHandshakeClusterName returns a copy of addr in which the Attributes field
// is updated with the cluster name.
func SetXDSHandshakeClusterName(addr resolver.Address, clusterName string) resolver.Address {
addr.Attributes = addr.Attributes.WithValue(handshakeClusterNameKey{}, clusterName)
return addr
}
// GetXDSHandshakeClusterName returns cluster name stored in attr.
func GetXDSHandshakeClusterName(attr *attributes.Attributes) (string, bool) {
v := attr.Value(handshakeClusterNameKey{})
name, ok := v.(string)
return name, ok
}

View File

@@ -22,7 +22,9 @@ package peer
import (
"context"
"fmt"
"net"
"strings"
"google.golang.org/grpc/credentials"
)
@@ -39,6 +41,34 @@ type Peer struct {
AuthInfo credentials.AuthInfo
}
// String ensures the Peer types implements the Stringer interface in order to
// allow to print a context with a peerKey value effectively.
func (p *Peer) String() string {
if p == nil {
return "Peer<nil>"
}
sb := &strings.Builder{}
sb.WriteString("Peer{")
if p.Addr != nil {
fmt.Fprintf(sb, "Addr: '%s', ", p.Addr.String())
} else {
fmt.Fprintf(sb, "Addr: <nil>, ")
}
if p.LocalAddr != nil {
fmt.Fprintf(sb, "LocalAddr: '%s', ", p.LocalAddr.String())
} else {
fmt.Fprintf(sb, "LocalAddr: <nil>, ")
}
if p.AuthInfo != nil {
fmt.Fprintf(sb, "AuthInfo: '%s'", p.AuthInfo.AuthType())
} else {
fmt.Fprintf(sb, "AuthInfo: <nil>")
}
sb.WriteString("}")
return sb.String()
}
type peerKey struct{}
// NewContext creates a new context with peer information attached.

View File

@@ -20,8 +20,9 @@ package grpc
import (
"context"
"fmt"
"io"
"sync"
"sync/atomic"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
@@ -32,35 +33,43 @@ import (
"google.golang.org/grpc/status"
)
// pickerGeneration stores a picker and a channel used to signal that a picker
// newer than this one is available.
type pickerGeneration struct {
// picker is the picker produced by the LB policy. May be nil if a picker
// has never been produced.
picker balancer.Picker
// blockingCh is closed when the picker has been invalidated because there
// is a new one available.
blockingCh chan struct{}
}
// pickerWrapper is a wrapper of balancer.Picker. It blocks on certain pick
// actions and unblock when there's a picker update.
type pickerWrapper struct {
mu sync.Mutex
done bool
blockingCh chan struct{}
picker balancer.Picker
// If pickerGen holds a nil pointer, the pickerWrapper is closed.
pickerGen atomic.Pointer[pickerGeneration]
statsHandlers []stats.Handler // to record blocking picker calls
}
func newPickerWrapper(statsHandlers []stats.Handler) *pickerWrapper {
return &pickerWrapper{
blockingCh: make(chan struct{}),
pw := &pickerWrapper{
statsHandlers: statsHandlers,
}
pw.pickerGen.Store(&pickerGeneration{
blockingCh: make(chan struct{}),
})
return pw
}
// updatePicker is called by UpdateBalancerState. It unblocks all blocked pick.
// updatePicker is called by UpdateState calls from the LB policy. It
// unblocks all blocked pick.
func (pw *pickerWrapper) updatePicker(p balancer.Picker) {
pw.mu.Lock()
if pw.done {
pw.mu.Unlock()
return
}
pw.picker = p
// pw.blockingCh should never be nil.
close(pw.blockingCh)
pw.blockingCh = make(chan struct{})
pw.mu.Unlock()
old := pw.pickerGen.Swap(&pickerGeneration{
picker: p,
blockingCh: make(chan struct{}),
})
close(old.blockingCh)
}
// doneChannelzWrapper performs the following:
@@ -97,27 +106,24 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
var lastPickErr error
for {
pw.mu.Lock()
if pw.done {
pw.mu.Unlock()
pg := pw.pickerGen.Load()
if pg == nil {
return nil, balancer.PickResult{}, ErrClientConnClosing
}
if pw.picker == nil {
ch = pw.blockingCh
if pg.picker == nil {
ch = pg.blockingCh
}
if ch == pw.blockingCh {
if ch == pg.blockingCh {
// This could happen when either:
// - pw.picker is nil (the previous if condition), or
// - has called pick on the current picker.
pw.mu.Unlock()
// - we have already called pick on the current picker.
select {
case <-ctx.Done():
var errStr string
if lastPickErr != nil {
errStr = "latest balancer error: " + lastPickErr.Error()
} else {
errStr = ctx.Err().Error()
errStr = fmt.Sprintf("received context error while waiting for new LB policy update: %s", ctx.Err().Error())
}
switch ctx.Err() {
case context.DeadlineExceeded:
@@ -144,9 +150,8 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
}
}
ch = pw.blockingCh
p := pw.picker
pw.mu.Unlock()
ch = pg.blockingCh
p := pg.picker
pickResult, err := p.Pick(info)
if err != nil {
@@ -196,24 +201,15 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
}
func (pw *pickerWrapper) close() {
pw.mu.Lock()
defer pw.mu.Unlock()
if pw.done {
return
}
pw.done = true
close(pw.blockingCh)
old := pw.pickerGen.Swap(nil)
close(old.blockingCh)
}
// reset clears the pickerWrapper and prepares it for being used again when idle
// mode is exited.
func (pw *pickerWrapper) reset() {
pw.mu.Lock()
defer pw.mu.Unlock()
if pw.done {
return
}
pw.blockingCh = make(chan struct{})
old := pw.pickerGen.Swap(&pickerGeneration{blockingCh: make(chan struct{})})
close(old.blockingCh)
}
// dropError is a wrapper error that indicates the LB policy wishes to drop the

View File

@@ -19,10 +19,11 @@
package reflection
import (
"google.golang.org/grpc/reflection/internal"
v1reflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1"
v1reflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1"
v1alphareflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
v1alphareflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
)
// asV1Alpha returns an implementation of the v1alpha version of the reflection
@@ -44,7 +45,7 @@ type v1AlphaServerStreamAdapter struct {
}
func (s v1AlphaServerStreamAdapter) Send(response *v1reflectionpb.ServerReflectionResponse) error {
return s.ServerReflection_ServerReflectionInfoServer.Send(v1ToV1AlphaResponse(response))
return s.ServerReflection_ServerReflectionInfoServer.Send(internal.V1ToV1AlphaResponse(response))
}
func (s v1AlphaServerStreamAdapter) Recv() (*v1reflectionpb.ServerReflectionRequest, error) {
@@ -52,136 +53,5 @@ func (s v1AlphaServerStreamAdapter) Recv() (*v1reflectionpb.ServerReflectionRequ
if err != nil {
return nil, err
}
return v1AlphaToV1Request(resp), nil
}
func v1ToV1AlphaResponse(v1 *v1reflectionpb.ServerReflectionResponse) *v1alphareflectionpb.ServerReflectionResponse {
var v1alpha v1alphareflectionpb.ServerReflectionResponse
v1alpha.ValidHost = v1.ValidHost
if v1.OriginalRequest != nil {
v1alpha.OriginalRequest = v1ToV1AlphaRequest(v1.OriginalRequest)
}
switch mr := v1.MessageResponse.(type) {
case *v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse:
if mr != nil {
v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &v1alphareflectionpb.FileDescriptorResponse{
FileDescriptorProto: mr.FileDescriptorResponse.GetFileDescriptorProto(),
},
}
}
case *v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse:
if mr != nil {
v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
AllExtensionNumbersResponse: &v1alphareflectionpb.ExtensionNumberResponse{
BaseTypeName: mr.AllExtensionNumbersResponse.GetBaseTypeName(),
ExtensionNumber: mr.AllExtensionNumbersResponse.GetExtensionNumber(),
},
}
}
case *v1reflectionpb.ServerReflectionResponse_ListServicesResponse:
if mr != nil {
svcs := make([]*v1alphareflectionpb.ServiceResponse, len(mr.ListServicesResponse.GetService()))
for i, svc := range mr.ListServicesResponse.GetService() {
svcs[i] = &v1alphareflectionpb.ServiceResponse{
Name: svc.GetName(),
}
}
v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_ListServicesResponse{
ListServicesResponse: &v1alphareflectionpb.ListServiceResponse{
Service: svcs,
},
}
}
case *v1reflectionpb.ServerReflectionResponse_ErrorResponse:
if mr != nil {
v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &v1alphareflectionpb.ErrorResponse{
ErrorCode: mr.ErrorResponse.GetErrorCode(),
ErrorMessage: mr.ErrorResponse.GetErrorMessage(),
},
}
}
default:
// no value set
}
return &v1alpha
}
func v1AlphaToV1Request(v1alpha *v1alphareflectionpb.ServerReflectionRequest) *v1reflectionpb.ServerReflectionRequest {
var v1 v1reflectionpb.ServerReflectionRequest
v1.Host = v1alpha.Host
switch mr := v1alpha.MessageRequest.(type) {
case *v1alphareflectionpb.ServerReflectionRequest_FileByFilename:
v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileByFilename{
FileByFilename: mr.FileByFilename,
}
case *v1alphareflectionpb.ServerReflectionRequest_FileContainingSymbol:
v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileContainingSymbol{
FileContainingSymbol: mr.FileContainingSymbol,
}
case *v1alphareflectionpb.ServerReflectionRequest_FileContainingExtension:
if mr.FileContainingExtension != nil {
v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileContainingExtension{
FileContainingExtension: &v1reflectionpb.ExtensionRequest{
ContainingType: mr.FileContainingExtension.GetContainingType(),
ExtensionNumber: mr.FileContainingExtension.GetExtensionNumber(),
},
}
}
case *v1alphareflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType{
AllExtensionNumbersOfType: mr.AllExtensionNumbersOfType,
}
case *v1alphareflectionpb.ServerReflectionRequest_ListServices:
v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_ListServices{
ListServices: mr.ListServices,
}
default:
// no value set
}
return &v1
}
func v1ToV1AlphaRequest(v1 *v1reflectionpb.ServerReflectionRequest) *v1alphareflectionpb.ServerReflectionRequest {
var v1alpha v1alphareflectionpb.ServerReflectionRequest
v1alpha.Host = v1.Host
switch mr := v1.MessageRequest.(type) {
case *v1reflectionpb.ServerReflectionRequest_FileByFilename:
if mr != nil {
v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileByFilename{
FileByFilename: mr.FileByFilename,
}
}
case *v1reflectionpb.ServerReflectionRequest_FileContainingSymbol:
if mr != nil {
v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileContainingSymbol{
FileContainingSymbol: mr.FileContainingSymbol,
}
}
case *v1reflectionpb.ServerReflectionRequest_FileContainingExtension:
if mr != nil {
v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileContainingExtension{
FileContainingExtension: &v1alphareflectionpb.ExtensionRequest{
ContainingType: mr.FileContainingExtension.GetContainingType(),
ExtensionNumber: mr.FileContainingExtension.GetExtensionNumber(),
},
}
}
case *v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
if mr != nil {
v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType{
AllExtensionNumbersOfType: mr.AllExtensionNumbersOfType,
}
}
case *v1reflectionpb.ServerReflectionRequest_ListServices:
if mr != nil {
v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_ListServices{
ListServices: mr.ListServices,
}
}
default:
// no value set
}
return &v1alpha
return internal.V1AlphaToV1Request(resp), nil
}

View File

@@ -21,7 +21,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.32.0
// protoc-gen-go v1.34.1
// protoc v4.25.2
// source: grpc/reflection/v1/reflection.proto

View File

@@ -21,7 +21,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.3.0
// - protoc-gen-go-grpc v1.4.0
// - protoc v4.25.2
// source: grpc/reflection/v1/reflection.proto
@@ -36,8 +36,8 @@ import (
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// Requires gRPC-Go v1.62.0 or later.
const _ = grpc.SupportPackageIsVersion8
const (
ServerReflection_ServerReflectionInfo_FullMethodName = "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo"
@@ -61,11 +61,12 @@ func NewServerReflectionClient(cc grpc.ClientConnInterface) ServerReflectionClie
}
func (c *serverReflectionClient) ServerReflectionInfo(ctx context.Context, opts ...grpc.CallOption) (ServerReflection_ServerReflectionInfoClient, error) {
stream, err := c.cc.NewStream(ctx, &ServerReflection_ServiceDesc.Streams[0], ServerReflection_ServerReflectionInfo_FullMethodName, opts...)
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &ServerReflection_ServiceDesc.Streams[0], ServerReflection_ServerReflectionInfo_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &serverReflectionServerReflectionInfoClient{stream}
x := &serverReflectionServerReflectionInfoClient{ClientStream: stream}
return x, nil
}
@@ -120,7 +121,7 @@ func RegisterServerReflectionServer(s grpc.ServiceRegistrar, srv ServerReflectio
}
func _ServerReflection_ServerReflectionInfo_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(ServerReflectionServer).ServerReflectionInfo(&serverReflectionServerReflectionInfoServer{stream})
return srv.(ServerReflectionServer).ServerReflectionInfo(&serverReflectionServerReflectionInfoServer{ServerStream: stream})
}
type ServerReflection_ServerReflectionInfoServer interface {

View File

@@ -18,7 +18,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.32.0
// protoc-gen-go v1.34.1
// protoc v4.25.2
// grpc/reflection/v1alpha/reflection.proto is a deprecated file.

View File

@@ -18,7 +18,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.3.0
// - protoc-gen-go-grpc v1.4.0
// - protoc v4.25.2
// grpc/reflection/v1alpha/reflection.proto is a deprecated file.
@@ -33,8 +33,8 @@ import (
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// Requires gRPC-Go v1.62.0 or later.
const _ = grpc.SupportPackageIsVersion8
const (
ServerReflection_ServerReflectionInfo_FullMethodName = "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo"
@@ -58,11 +58,12 @@ func NewServerReflectionClient(cc grpc.ClientConnInterface) ServerReflectionClie
}
func (c *serverReflectionClient) ServerReflectionInfo(ctx context.Context, opts ...grpc.CallOption) (ServerReflection_ServerReflectionInfoClient, error) {
stream, err := c.cc.NewStream(ctx, &ServerReflection_ServiceDesc.Streams[0], ServerReflection_ServerReflectionInfo_FullMethodName, opts...)
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &ServerReflection_ServiceDesc.Streams[0], ServerReflection_ServerReflectionInfo_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &serverReflectionServerReflectionInfoClient{stream}
x := &serverReflectionServerReflectionInfoClient{ClientStream: stream}
return x, nil
}
@@ -117,7 +118,7 @@ func RegisterServerReflectionServer(s grpc.ServiceRegistrar, srv ServerReflectio
}
func _ServerReflection_ServerReflectionInfo_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(ServerReflectionServer).ServerReflectionInfo(&serverReflectionServerReflectionInfoServer{stream})
return srv.(ServerReflectionServer).ServerReflectionInfo(&serverReflectionServerReflectionInfoServer{ServerStream: stream})
}
type ServerReflection_ServerReflectionInfoServer interface {

View File

@@ -0,0 +1,436 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package internal contains code that is shared by both reflection package and
// the test package. The packages are split in this way inorder to avoid
// depenedency to deprecated package github.com/golang/protobuf.
package internal
import (
"io"
"sort"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
v1reflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1"
v1reflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1"
v1alphareflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
v1alphareflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
)
// ServiceInfoProvider is an interface used to retrieve metadata about the
// services to expose.
type ServiceInfoProvider interface {
GetServiceInfo() map[string]grpc.ServiceInfo
}
// ExtensionResolver is the interface used to query details about extensions.
// This interface is satisfied by protoregistry.GlobalTypes.
type ExtensionResolver interface {
protoregistry.ExtensionTypeResolver
RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool)
}
// ServerReflectionServer is the server API for ServerReflection service.
type ServerReflectionServer struct {
v1alphareflectiongrpc.UnimplementedServerReflectionServer
S ServiceInfoProvider
DescResolver protodesc.Resolver
ExtResolver ExtensionResolver
}
// FileDescWithDependencies returns a slice of serialized fileDescriptors in
// wire format ([]byte). The fileDescriptors will include fd and all the
// transitive dependencies of fd with names not in sentFileDescriptors.
func (s *ServerReflectionServer) FileDescWithDependencies(fd protoreflect.FileDescriptor, sentFileDescriptors map[string]bool) ([][]byte, error) {
if fd.IsPlaceholder() {
// If the given root file is a placeholder, treat it
// as missing instead of serializing it.
return nil, protoregistry.NotFound
}
var r [][]byte
queue := []protoreflect.FileDescriptor{fd}
for len(queue) > 0 {
currentfd := queue[0]
queue = queue[1:]
if currentfd.IsPlaceholder() {
// Skip any missing files in the dependency graph.
continue
}
if sent := sentFileDescriptors[currentfd.Path()]; len(r) == 0 || !sent {
sentFileDescriptors[currentfd.Path()] = true
fdProto := protodesc.ToFileDescriptorProto(currentfd)
currentfdEncoded, err := proto.Marshal(fdProto)
if err != nil {
return nil, err
}
r = append(r, currentfdEncoded)
}
for i := 0; i < currentfd.Imports().Len(); i++ {
queue = append(queue, currentfd.Imports().Get(i))
}
}
return r, nil
}
// FileDescEncodingContainingSymbol finds the file descriptor containing the
// given symbol, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result. The given symbol
// can be a type, a service or a method.
func (s *ServerReflectionServer) FileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
d, err := s.DescResolver.FindDescriptorByName(protoreflect.FullName(name))
if err != nil {
return nil, err
}
return s.FileDescWithDependencies(d.ParentFile(), sentFileDescriptors)
}
// FileDescEncodingContainingExtension finds the file descriptor containing
// given extension, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result.
func (s *ServerReflectionServer) FileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
xt, err := s.ExtResolver.FindExtensionByNumber(protoreflect.FullName(typeName), protoreflect.FieldNumber(extNum))
if err != nil {
return nil, err
}
return s.FileDescWithDependencies(xt.TypeDescriptor().ParentFile(), sentFileDescriptors)
}
// AllExtensionNumbersForTypeName returns all extension numbers for the given type.
func (s *ServerReflectionServer) AllExtensionNumbersForTypeName(name string) ([]int32, error) {
var numbers []int32
s.ExtResolver.RangeExtensionsByMessage(protoreflect.FullName(name), func(xt protoreflect.ExtensionType) bool {
numbers = append(numbers, int32(xt.TypeDescriptor().Number()))
return true
})
sort.Slice(numbers, func(i, j int) bool {
return numbers[i] < numbers[j]
})
if len(numbers) == 0 {
// maybe return an error if given type name is not known
if _, err := s.DescResolver.FindDescriptorByName(protoreflect.FullName(name)); err != nil {
return nil, err
}
}
return numbers, nil
}
// ListServices returns the names of services this server exposes.
func (s *ServerReflectionServer) ListServices() []*v1reflectionpb.ServiceResponse {
serviceInfo := s.S.GetServiceInfo()
resp := make([]*v1reflectionpb.ServiceResponse, 0, len(serviceInfo))
for svc := range serviceInfo {
resp = append(resp, &v1reflectionpb.ServiceResponse{Name: svc})
}
sort.Slice(resp, func(i, j int) bool {
return resp[i].Name < resp[j].Name
})
return resp
}
// ServerReflectionInfo is the reflection service handler.
func (s *ServerReflectionServer) ServerReflectionInfo(stream v1reflectiongrpc.ServerReflection_ServerReflectionInfoServer) error {
sentFileDescriptors := make(map[string]bool)
for {
in, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
out := &v1reflectionpb.ServerReflectionResponse{
ValidHost: in.Host,
OriginalRequest: in,
}
switch req := in.MessageRequest.(type) {
case *v1reflectionpb.ServerReflectionRequest_FileByFilename:
var b [][]byte
fd, err := s.DescResolver.FindFileByPath(req.FileByFilename)
if err == nil {
b, err = s.FileDescWithDependencies(fd, sentFileDescriptors)
}
if err != nil {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &v1reflectionpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *v1reflectionpb.ServerReflectionRequest_FileContainingSymbol:
b, err := s.FileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors)
if err != nil {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &v1reflectionpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *v1reflectionpb.ServerReflectionRequest_FileContainingExtension:
typeName := req.FileContainingExtension.ContainingType
extNum := req.FileContainingExtension.ExtensionNumber
b, err := s.FileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors)
if err != nil {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &v1reflectionpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
extNums, err := s.AllExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
if err != nil {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &v1reflectionpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
AllExtensionNumbersResponse: &v1reflectionpb.ExtensionNumberResponse{
BaseTypeName: req.AllExtensionNumbersOfType,
ExtensionNumber: extNums,
},
}
}
case *v1reflectionpb.ServerReflectionRequest_ListServices:
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ListServicesResponse{
ListServicesResponse: &v1reflectionpb.ListServiceResponse{
Service: s.ListServices(),
},
}
default:
return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
}
if err := stream.Send(out); err != nil {
return err
}
}
}
// V1ToV1AlphaResponse converts a v1 ServerReflectionResponse to a v1alpha.
func V1ToV1AlphaResponse(v1 *v1reflectionpb.ServerReflectionResponse) *v1alphareflectionpb.ServerReflectionResponse {
var v1alpha v1alphareflectionpb.ServerReflectionResponse
v1alpha.ValidHost = v1.ValidHost
if v1.OriginalRequest != nil {
v1alpha.OriginalRequest = V1ToV1AlphaRequest(v1.OriginalRequest)
}
switch mr := v1.MessageResponse.(type) {
case *v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse:
if mr != nil {
v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &v1alphareflectionpb.FileDescriptorResponse{
FileDescriptorProto: mr.FileDescriptorResponse.GetFileDescriptorProto(),
},
}
}
case *v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse:
if mr != nil {
v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
AllExtensionNumbersResponse: &v1alphareflectionpb.ExtensionNumberResponse{
BaseTypeName: mr.AllExtensionNumbersResponse.GetBaseTypeName(),
ExtensionNumber: mr.AllExtensionNumbersResponse.GetExtensionNumber(),
},
}
}
case *v1reflectionpb.ServerReflectionResponse_ListServicesResponse:
if mr != nil {
svcs := make([]*v1alphareflectionpb.ServiceResponse, len(mr.ListServicesResponse.GetService()))
for i, svc := range mr.ListServicesResponse.GetService() {
svcs[i] = &v1alphareflectionpb.ServiceResponse{
Name: svc.GetName(),
}
}
v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_ListServicesResponse{
ListServicesResponse: &v1alphareflectionpb.ListServiceResponse{
Service: svcs,
},
}
}
case *v1reflectionpb.ServerReflectionResponse_ErrorResponse:
if mr != nil {
v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &v1alphareflectionpb.ErrorResponse{
ErrorCode: mr.ErrorResponse.GetErrorCode(),
ErrorMessage: mr.ErrorResponse.GetErrorMessage(),
},
}
}
default:
// no value set
}
return &v1alpha
}
// V1AlphaToV1Request converts a v1alpha ServerReflectionRequest to a v1.
func V1AlphaToV1Request(v1alpha *v1alphareflectionpb.ServerReflectionRequest) *v1reflectionpb.ServerReflectionRequest {
var v1 v1reflectionpb.ServerReflectionRequest
v1.Host = v1alpha.Host
switch mr := v1alpha.MessageRequest.(type) {
case *v1alphareflectionpb.ServerReflectionRequest_FileByFilename:
v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileByFilename{
FileByFilename: mr.FileByFilename,
}
case *v1alphareflectionpb.ServerReflectionRequest_FileContainingSymbol:
v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileContainingSymbol{
FileContainingSymbol: mr.FileContainingSymbol,
}
case *v1alphareflectionpb.ServerReflectionRequest_FileContainingExtension:
if mr.FileContainingExtension != nil {
v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileContainingExtension{
FileContainingExtension: &v1reflectionpb.ExtensionRequest{
ContainingType: mr.FileContainingExtension.GetContainingType(),
ExtensionNumber: mr.FileContainingExtension.GetExtensionNumber(),
},
}
}
case *v1alphareflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType{
AllExtensionNumbersOfType: mr.AllExtensionNumbersOfType,
}
case *v1alphareflectionpb.ServerReflectionRequest_ListServices:
v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_ListServices{
ListServices: mr.ListServices,
}
default:
// no value set
}
return &v1
}
// V1ToV1AlphaRequest converts a v1 ServerReflectionRequest to a v1alpha.
func V1ToV1AlphaRequest(v1 *v1reflectionpb.ServerReflectionRequest) *v1alphareflectionpb.ServerReflectionRequest {
var v1alpha v1alphareflectionpb.ServerReflectionRequest
v1alpha.Host = v1.Host
switch mr := v1.MessageRequest.(type) {
case *v1reflectionpb.ServerReflectionRequest_FileByFilename:
if mr != nil {
v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileByFilename{
FileByFilename: mr.FileByFilename,
}
}
case *v1reflectionpb.ServerReflectionRequest_FileContainingSymbol:
if mr != nil {
v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileContainingSymbol{
FileContainingSymbol: mr.FileContainingSymbol,
}
}
case *v1reflectionpb.ServerReflectionRequest_FileContainingExtension:
if mr != nil {
v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileContainingExtension{
FileContainingExtension: &v1alphareflectionpb.ExtensionRequest{
ContainingType: mr.FileContainingExtension.GetContainingType(),
ExtensionNumber: mr.FileContainingExtension.GetExtensionNumber(),
},
}
}
case *v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
if mr != nil {
v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType{
AllExtensionNumbersOfType: mr.AllExtensionNumbersOfType,
}
}
case *v1reflectionpb.ServerReflectionRequest_ListServices:
if mr != nil {
v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_ListServices{
ListServices: mr.ListServices,
}
}
default:
// no value set
}
return &v1alpha
}
// V1AlphaToV1Response converts a v1alpha ServerReflectionResponse to a v1.
func V1AlphaToV1Response(v1alpha *v1alphareflectionpb.ServerReflectionResponse) *v1reflectionpb.ServerReflectionResponse {
var v1 v1reflectionpb.ServerReflectionResponse
v1.ValidHost = v1alpha.ValidHost
if v1alpha.OriginalRequest != nil {
v1.OriginalRequest = V1AlphaToV1Request(v1alpha.OriginalRequest)
}
switch mr := v1alpha.MessageResponse.(type) {
case *v1alphareflectionpb.ServerReflectionResponse_FileDescriptorResponse:
if mr != nil {
v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{
FileDescriptorProto: mr.FileDescriptorResponse.GetFileDescriptorProto(),
},
}
}
case *v1alphareflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse:
if mr != nil {
v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
AllExtensionNumbersResponse: &v1reflectionpb.ExtensionNumberResponse{
BaseTypeName: mr.AllExtensionNumbersResponse.GetBaseTypeName(),
ExtensionNumber: mr.AllExtensionNumbersResponse.GetExtensionNumber(),
},
}
}
case *v1alphareflectionpb.ServerReflectionResponse_ListServicesResponse:
if mr != nil {
svcs := make([]*v1reflectionpb.ServiceResponse, len(mr.ListServicesResponse.GetService()))
for i, svc := range mr.ListServicesResponse.GetService() {
svcs[i] = &v1reflectionpb.ServiceResponse{
Name: svc.GetName(),
}
}
v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ListServicesResponse{
ListServicesResponse: &v1reflectionpb.ListServiceResponse{
Service: svcs,
},
}
}
case *v1alphareflectionpb.ServerReflectionResponse_ErrorResponse:
if mr != nil {
v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &v1reflectionpb.ErrorResponse{
ErrorCode: mr.ErrorResponse.GetErrorCode(),
ErrorMessage: mr.ErrorResponse.GetErrorMessage(),
},
}
}
default:
// no value set
}
return &v1
}

View File

@@ -37,19 +37,13 @@ To register server reflection on a gRPC server:
package reflection // import "google.golang.org/grpc/reflection"
import (
"io"
"sort"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/grpc/reflection/internal"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
v1reflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1"
v1reflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1"
v1alphareflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
)
@@ -158,203 +152,9 @@ func NewServerV1(opts ServerOptions) v1reflectiongrpc.ServerReflectionServer {
if opts.ExtensionResolver == nil {
opts.ExtensionResolver = protoregistry.GlobalTypes
}
return &serverReflectionServer{
s: opts.Services,
descResolver: opts.DescriptorResolver,
extResolver: opts.ExtensionResolver,
}
}
type serverReflectionServer struct {
v1alphareflectiongrpc.UnimplementedServerReflectionServer
s ServiceInfoProvider
descResolver protodesc.Resolver
extResolver ExtensionResolver
}
// fileDescWithDependencies returns a slice of serialized fileDescriptors in
// wire format ([]byte). The fileDescriptors will include fd and all the
// transitive dependencies of fd with names not in sentFileDescriptors.
func (s *serverReflectionServer) fileDescWithDependencies(fd protoreflect.FileDescriptor, sentFileDescriptors map[string]bool) ([][]byte, error) {
if fd.IsPlaceholder() {
// If the given root file is a placeholder, treat it
// as missing instead of serializing it.
return nil, protoregistry.NotFound
}
var r [][]byte
queue := []protoreflect.FileDescriptor{fd}
for len(queue) > 0 {
currentfd := queue[0]
queue = queue[1:]
if currentfd.IsPlaceholder() {
// Skip any missing files in the dependency graph.
continue
}
if sent := sentFileDescriptors[currentfd.Path()]; len(r) == 0 || !sent {
sentFileDescriptors[currentfd.Path()] = true
fdProto := protodesc.ToFileDescriptorProto(currentfd)
currentfdEncoded, err := proto.Marshal(fdProto)
if err != nil {
return nil, err
}
r = append(r, currentfdEncoded)
}
for i := 0; i < currentfd.Imports().Len(); i++ {
queue = append(queue, currentfd.Imports().Get(i))
}
}
return r, nil
}
// fileDescEncodingContainingSymbol finds the file descriptor containing the
// given symbol, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result. The given symbol
// can be a type, a service or a method.
func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
d, err := s.descResolver.FindDescriptorByName(protoreflect.FullName(name))
if err != nil {
return nil, err
}
return s.fileDescWithDependencies(d.ParentFile(), sentFileDescriptors)
}
// fileDescEncodingContainingExtension finds the file descriptor containing
// given extension, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
xt, err := s.extResolver.FindExtensionByNumber(protoreflect.FullName(typeName), protoreflect.FieldNumber(extNum))
if err != nil {
return nil, err
}
return s.fileDescWithDependencies(xt.TypeDescriptor().ParentFile(), sentFileDescriptors)
}
// allExtensionNumbersForTypeName returns all extension numbers for the given type.
func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) {
var numbers []int32
s.extResolver.RangeExtensionsByMessage(protoreflect.FullName(name), func(xt protoreflect.ExtensionType) bool {
numbers = append(numbers, int32(xt.TypeDescriptor().Number()))
return true
})
sort.Slice(numbers, func(i, j int) bool {
return numbers[i] < numbers[j]
})
if len(numbers) == 0 {
// maybe return an error if given type name is not known
if _, err := s.descResolver.FindDescriptorByName(protoreflect.FullName(name)); err != nil {
return nil, err
}
}
return numbers, nil
}
// listServices returns the names of services this server exposes.
func (s *serverReflectionServer) listServices() []*v1reflectionpb.ServiceResponse {
serviceInfo := s.s.GetServiceInfo()
resp := make([]*v1reflectionpb.ServiceResponse, 0, len(serviceInfo))
for svc := range serviceInfo {
resp = append(resp, &v1reflectionpb.ServiceResponse{Name: svc})
}
sort.Slice(resp, func(i, j int) bool {
return resp[i].Name < resp[j].Name
})
return resp
}
// ServerReflectionInfo is the reflection service handler.
func (s *serverReflectionServer) ServerReflectionInfo(stream v1reflectiongrpc.ServerReflection_ServerReflectionInfoServer) error {
sentFileDescriptors := make(map[string]bool)
for {
in, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
out := &v1reflectionpb.ServerReflectionResponse{
ValidHost: in.Host,
OriginalRequest: in,
}
switch req := in.MessageRequest.(type) {
case *v1reflectionpb.ServerReflectionRequest_FileByFilename:
var b [][]byte
fd, err := s.descResolver.FindFileByPath(req.FileByFilename)
if err == nil {
b, err = s.fileDescWithDependencies(fd, sentFileDescriptors)
}
if err != nil {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &v1reflectionpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *v1reflectionpb.ServerReflectionRequest_FileContainingSymbol:
b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors)
if err != nil {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &v1reflectionpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *v1reflectionpb.ServerReflectionRequest_FileContainingExtension:
typeName := req.FileContainingExtension.ContainingType
extNum := req.FileContainingExtension.ExtensionNumber
b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors)
if err != nil {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &v1reflectionpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
if err != nil {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &v1reflectionpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
AllExtensionNumbersResponse: &v1reflectionpb.ExtensionNumberResponse{
BaseTypeName: req.AllExtensionNumbersOfType,
ExtensionNumber: extNums,
},
}
}
case *v1reflectionpb.ServerReflectionRequest_ListServices:
out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ListServicesResponse{
ListServicesResponse: &v1reflectionpb.ListServiceResponse{
Service: s.listServices(),
},
}
default:
return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
}
if err := stream.Send(out); err != nil {
return err
}
return &internal.ServerReflectionServer{
S: opts.Services,
DescResolver: opts.DescriptorResolver,
ExtResolver: opts.ExtensionResolver,
}
}

View File

@@ -63,7 +63,7 @@ LEGACY_SOURCES=(
# Generates only the new gRPC Service symbols
SOURCES=(
$(git ls-files --exclude-standard --cached --others "*.proto" | grep -v '^\(profiling/proto/service.proto\|reflection/grpc_reflection_v1alpha/reflection.proto\)$')
$(git ls-files --exclude-standard --cached --others "*.proto" | grep -v '^profiling/proto/service.proto$')
${WORKDIR}/grpc-proto/grpc/gcp/altscontext.proto
${WORKDIR}/grpc-proto/grpc/gcp/handshaker.proto
${WORKDIR}/grpc-proto/grpc/gcp/transport_security_common.proto
@@ -93,7 +93,7 @@ Mgrpc/testing/empty.proto=google.golang.org/grpc/interop/grpc_testing
for src in ${SOURCES[@]}; do
echo "protoc ${src}"
protoc --go_out=${OPTS}:${WORKDIR}/out --go-grpc_out=${OPTS}:${WORKDIR}/out \
protoc --go_out=${OPTS}:${WORKDIR}/out --go-grpc_out=${OPTS},use_generic_streams_experimental=true:${WORKDIR}/out \
-I"." \
-I${WORKDIR}/grpc-proto \
-I${WORKDIR}/googleapis \
@@ -118,6 +118,6 @@ mv ${WORKDIR}/out/google.golang.org/grpc/lookup/grpc_lookup_v1/* ${WORKDIR}/out/
# grpc_testing_not_regenerate/*.pb.go are not re-generated,
# see grpc_testing_not_regenerate/README.md for details.
rm ${WORKDIR}/out/google.golang.org/grpc/reflection/grpc_testing_not_regenerate/*.pb.go
rm ${WORKDIR}/out/google.golang.org/grpc/reflection/test/grpc_testing_not_regenerate/*.pb.go
cp -R ${WORKDIR}/out/google.golang.org/grpc/* .

View File

@@ -18,19 +18,43 @@
// Package dns implements a dns resolver to be installed as the default resolver
// in grpc.
//
// Deprecated: this package is imported by grpc and should not need to be
// imported directly by users.
package dns
import (
"time"
"google.golang.org/grpc/internal/resolver/dns"
"google.golang.org/grpc/resolver"
)
// SetResolvingTimeout sets the maximum duration for DNS resolution requests.
//
// This function affects the global timeout used by all channels using the DNS
// name resolver scheme.
//
// It must be called only at application startup, before any gRPC calls are
// made. Modifying this value after initialization is not thread-safe.
//
// The default value is 30 seconds. Setting the timeout too low may result in
// premature timeouts during resolution, while setting it too high may lead to
// unnecessary delays in service discovery. Choose a value appropriate for your
// specific needs and network environment.
func SetResolvingTimeout(timeout time.Duration) {
dns.ResolvingTimeout = timeout
}
// NewBuilder creates a dnsBuilder which is used to factory DNS resolvers.
//
// Deprecated: import grpc and use resolver.Get("dns") instead.
func NewBuilder() resolver.Builder {
return dns.NewBuilder()
}
// SetMinResolutionInterval sets the default minimum interval at which DNS
// re-resolutions are allowed. This helps to prevent excessive re-resolution.
//
// It must be called only at application startup, before any gRPC calls are
// made. Modifying this value after initialization is not thread-safe.
func SetMinResolutionInterval(d time.Duration) {
dns.MinResolutionInterval = d
}

View File

@@ -29,6 +29,7 @@ import (
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/serviceconfig"
)
@@ -63,16 +64,18 @@ func Get(scheme string) Builder {
}
// SetDefaultScheme sets the default scheme that will be used. The default
// default scheme is "passthrough".
// scheme is initially set to "passthrough".
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. The scheme set last overrides
// previously set values.
func SetDefaultScheme(scheme string) {
defaultScheme = scheme
internal.UserSetDefaultScheme = true
}
// GetDefaultScheme gets the default scheme that will be used.
// GetDefaultScheme gets the default scheme that will be used by grpc.Dial. If
// SetDefaultScheme is never called, the default scheme used by grpc.NewClient is "dns" instead.
func GetDefaultScheme() string {
return defaultScheme
}
@@ -168,6 +171,9 @@ type BuildOptions struct {
// field. In most cases though, it is not appropriate, and this field may
// be ignored.
Dialer func(context.Context, string) (net.Conn, error)
// Authority is the effective authority of the clientconn for which the
// resolver is built.
Authority string
}
// An Endpoint is one network endpoint, or server, which may have multiple
@@ -281,9 +287,9 @@ func (t Target) Endpoint() string {
return strings.TrimPrefix(endpoint, "/")
}
// String returns a string representation of Target.
// String returns the canonical string representation of Target.
func (t Target) String() string {
return t.URL.String()
return t.URL.Scheme + "://" + t.URL.Host + "/" + t.Endpoint()
}
// Builder creates a resolver that will be used to watch name resolution updates.

View File

@@ -75,6 +75,7 @@ func (ccr *ccResolverWrapper) start() error {
DialCreds: ccr.cc.dopts.copts.TransportCredentials,
CredsBundle: ccr.cc.dopts.copts.CredsBundle,
Dialer: ccr.cc.dopts.copts.Dialer,
Authority: ccr.cc.authority,
}
var err error
ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts)
@@ -96,7 +97,7 @@ func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) {
// finished shutting down, the channel should block on ccr.serializer.Done()
// without cc.mu held.
func (ccr *ccResolverWrapper) close() {
channelz.Info(logger, ccr.cc.channelzID, "Closing the name resolver")
channelz.Info(logger, ccr.cc.channelz, "Closing the name resolver")
ccr.mu.Lock()
ccr.closed = true
ccr.mu.Unlock()
@@ -146,7 +147,7 @@ func (ccr *ccResolverWrapper) ReportError(err error) {
return
}
ccr.mu.Unlock()
channelz.Warningf(logger, ccr.cc.channelzID, "ccResolverWrapper: reporting error to cc: %v", err)
channelz.Warningf(logger, ccr.cc.channelz, "ccResolverWrapper: reporting error to cc: %v", err)
ccr.cc.updateResolverStateAndUnlock(resolver.State{}, err)
}
@@ -170,7 +171,7 @@ func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
// ParseServiceConfig is called by resolver implementations to parse a JSON
// representation of the service config.
func (ccr *ccResolverWrapper) ParseServiceConfig(scJSON string) *serviceconfig.ParseResult {
return parseServiceConfig(scJSON)
return parseServiceConfig(scJSON, ccr.cc.dopts.maxCallAttempts)
}
// addChannelzTraceEvent adds a channelz trace event containing the new
@@ -193,5 +194,5 @@ func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) {
} else if len(ccr.curState.Addresses) == 0 && len(s.Addresses) > 0 {
updates = append(updates, "resolver returned new addresses")
}
channelz.Infof(logger, ccr.cc.channelzID, "Resolver state updated: %s (%v)", pretty.ToJSON(s), strings.Join(updates, "; "))
channelz.Infof(logger, ccr.cc.channelz, "Resolver state updated: %s (%v)", pretty.ToJSON(s), strings.Join(updates, "; "))
}

View File

@@ -744,17 +744,19 @@ type payloadInfo struct {
uncompressedBytes []byte
}
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) ([]byte, error) {
pf, buf, err := p.recvMsg(maxReceiveMessageSize)
// recvAndDecompress reads a message from the stream, decompressing it if necessary.
//
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
// the buffer is no longer needed.
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
) (uncompressedBuf []byte, cancel func(), err error) {
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return nil, err
}
if payInfo != nil {
payInfo.compressedLength = len(buf)
return nil, nil, err
}
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
return nil, st.Err()
return nil, nil, st.Err()
}
var size int
@@ -762,21 +764,35 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default.
if dc != nil {
buf, err = dc.Do(bytes.NewReader(buf))
size = len(buf)
uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf))
size = len(uncompressedBuf)
} else {
buf, size, err = decompress(compressor, buf, maxReceiveMessageSize)
uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize)
}
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
}
if size > maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
}
} else {
uncompressedBuf = compressedBuf
}
if payInfo != nil {
payInfo.compressedLength = len(compressedBuf)
payInfo.uncompressedBytes = uncompressedBuf
cancel = func() {}
} else {
cancel = func() {
p.recvBufferPool.Put(&compressedBuf)
}
}
return buf, nil
return uncompressedBuf, cancel, nil
}
// Using compressor, decompress d, returning data and size.
@@ -796,6 +812,9 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
// size is used as an estimate to size the buffer, but we
// will read more data if available.
// +MinRead so ReadFrom will not reallocate if size is correct.
//
// TODO: If we ensure that the buffer size is the same as the DecompressedSize,
// we can also utilize the recv buffer pool here.
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
return buf.Bytes(), int(bytesRead), err
@@ -811,18 +830,15 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
buf, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
if err != nil {
return err
}
defer cancel()
if err := c.Unmarshal(buf, m); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
}
if payInfo != nil {
payInfo.uncompressedBytes = buf
} else {
p.recvBufferPool.Put(&buf)
}
return nil
}
@@ -946,22 +962,9 @@ func setCallInfoCodec(c *callInfo) error {
return nil
}
// channelzData is used to store channelz related data for ClientConn, addrConn and Server.
// These fields cannot be embedded in the original structs (e.g. ClientConn), since to do atomic
// operation on int64 variable on 32-bit machine, user is responsible to enforce memory alignment.
// Here, by grouping those int64 fields inside a struct, we are enforcing the alignment.
type channelzData struct {
callsStarted int64
callsFailed int64
callsSucceeded int64
// lastCallStartedTime stores the timestamp that last call starts. It is of int64 type instead of
// time.Time since it's more costly to atomically update time.Time variable than int64 variable.
lastCallStartedTime int64
}
// The SupportPackageIsVersion variables are referenced from generated protocol
// buffer files to ensure compatibility with the gRPC version used. The latest
// support package version is 7.
// support package version is 9.
//
// Older versions are kept for compatibility.
//
@@ -973,6 +976,7 @@ const (
SupportPackageIsVersion6 = true
SupportPackageIsVersion7 = true
SupportPackageIsVersion8 = true
SupportPackageIsVersion9 = true
)
const grpcUA = "grpc-go/" + Version

View File

@@ -137,8 +137,7 @@ type Server struct {
serveWG sync.WaitGroup // counts active Serve goroutines for Stop/GracefulStop
handlersWG sync.WaitGroup // counts active method handler goroutines
channelzID *channelz.Identifier
czData *channelzData
channelz *channelz.Server
serverWorkerChannel chan func()
serverWorkerChannelClose func()
@@ -249,11 +248,9 @@ func SharedWriteBuffer(val bool) ServerOption {
}
// WriteBufferSize determines how much data can be batched before doing a write
// on the wire. The corresponding memory allocation for this buffer will be
// twice the size to keep syscalls low. The default value for this buffer is
// 32KB. Zero or negative values will disable the write buffer such that each
// write will be on underlying connection.
// Note: A Send call may not directly translate to a write.
// on the wire. The default value for this buffer is 32KB. Zero or negative
// values will disable the write buffer such that each write will be on underlying
// connection. Note: A Send call may not directly translate to a write.
func WriteBufferSize(s int) ServerOption {
return newFuncServerOption(func(o *serverOptions) {
o.writeBufferSize = s
@@ -530,12 +527,22 @@ func ConnectionTimeout(d time.Duration) ServerOption {
})
}
// MaxHeaderListSizeServerOption is a ServerOption that sets the max
// (uncompressed) size of header list that the server is prepared to accept.
type MaxHeaderListSizeServerOption struct {
MaxHeaderListSize uint32
}
func (o MaxHeaderListSizeServerOption) apply(so *serverOptions) {
so.maxHeaderListSize = &o.MaxHeaderListSize
}
// MaxHeaderListSize returns a ServerOption that sets the max (uncompressed) size
// of header list that the server is prepared to accept.
func MaxHeaderListSize(s uint32) ServerOption {
return newFuncServerOption(func(o *serverOptions) {
o.maxHeaderListSize = &s
})
return MaxHeaderListSizeServerOption{
MaxHeaderListSize: s,
}
}
// HeaderTableSize returns a ServerOption that sets the size of dynamic
@@ -661,7 +668,7 @@ func NewServer(opt ...ServerOption) *Server {
services: make(map[string]*serviceInfo),
quit: grpcsync.NewEvent(),
done: grpcsync.NewEvent(),
czData: new(channelzData),
channelz: channelz.RegisterServer(""),
}
chainUnaryServerInterceptors(s)
chainStreamServerInterceptors(s)
@@ -675,8 +682,7 @@ func NewServer(opt ...ServerOption) *Server {
s.initServerWorkers()
}
s.channelzID = channelz.RegisterServer(&channelzServer{s}, "")
channelz.Info(logger, s.channelzID, "Server created")
channelz.Info(logger, s.channelz, "Server created")
return s
}
@@ -802,20 +808,13 @@ var ErrServerStopped = errors.New("grpc: the server has been stopped")
type listenSocket struct {
net.Listener
channelzID *channelz.Identifier
}
func (l *listenSocket) ChannelzMetric() *channelz.SocketInternalMetric {
return &channelz.SocketInternalMetric{
SocketOptions: channelz.GetSocketOption(l.Listener),
LocalAddr: l.Listener.Addr(),
}
channelz *channelz.Socket
}
func (l *listenSocket) Close() error {
err := l.Listener.Close()
channelz.RemoveEntry(l.channelzID)
channelz.Info(logger, l.channelzID, "ListenSocket deleted")
channelz.RemoveEntry(l.channelz.ID)
channelz.Info(logger, l.channelz, "ListenSocket deleted")
return err
}
@@ -857,7 +856,16 @@ func (s *Server) Serve(lis net.Listener) error {
}
}()
ls := &listenSocket{Listener: lis}
ls := &listenSocket{
Listener: lis,
channelz: channelz.RegisterSocket(&channelz.Socket{
SocketType: channelz.SocketTypeListen,
Parent: s.channelz,
RefName: lis.Addr().String(),
LocalAddr: lis.Addr(),
SocketOptions: channelz.GetSocketOption(lis)},
),
}
s.lis[ls] = true
defer func() {
@@ -869,14 +877,8 @@ func (s *Server) Serve(lis net.Listener) error {
s.mu.Unlock()
}()
var err error
ls.channelzID, err = channelz.RegisterListenSocket(ls, s.channelzID, lis.Addr().String())
if err != nil {
s.mu.Unlock()
return err
}
s.mu.Unlock()
channelz.Info(logger, ls.channelzID, "ListenSocket created")
channelz.Info(logger, ls.channelz, "ListenSocket created")
var tempDelay time.Duration // how long to sleep on accept failure
for {
@@ -975,7 +977,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
WriteBufferSize: s.opts.writeBufferSize,
ReadBufferSize: s.opts.readBufferSize,
SharedWriteBuffer: s.opts.sharedWriteBuffer,
ChannelzParentID: s.channelzID,
ChannelzParent: s.channelz,
MaxHeaderListSize: s.opts.maxHeaderListSize,
HeaderTableSize: s.opts.headerTableSize,
}
@@ -989,7 +991,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
if err != credentials.ErrConnDispatched {
// Don't log on ErrConnDispatched and io.EOF to prevent log spam.
if err != io.EOF {
channelz.Info(logger, s.channelzID, "grpc: Server.Serve failed to create ServerTransport: ", err)
channelz.Info(logger, s.channelz, "grpc: Server.Serve failed to create ServerTransport: ", err)
}
c.Close()
}
@@ -1121,37 +1123,28 @@ func (s *Server) removeConn(addr string, st transport.ServerTransport) {
}
}
func (s *Server) channelzMetric() *channelz.ServerInternalMetric {
return &channelz.ServerInternalMetric{
CallsStarted: atomic.LoadInt64(&s.czData.callsStarted),
CallsSucceeded: atomic.LoadInt64(&s.czData.callsSucceeded),
CallsFailed: atomic.LoadInt64(&s.czData.callsFailed),
LastCallStartedTimestamp: time.Unix(0, atomic.LoadInt64(&s.czData.lastCallStartedTime)),
}
}
func (s *Server) incrCallsStarted() {
atomic.AddInt64(&s.czData.callsStarted, 1)
atomic.StoreInt64(&s.czData.lastCallStartedTime, time.Now().UnixNano())
s.channelz.ServerMetrics.CallsStarted.Add(1)
s.channelz.ServerMetrics.LastCallStartedTimestamp.Store(time.Now().UnixNano())
}
func (s *Server) incrCallsSucceeded() {
atomic.AddInt64(&s.czData.callsSucceeded, 1)
s.channelz.ServerMetrics.CallsSucceeded.Add(1)
}
func (s *Server) incrCallsFailed() {
atomic.AddInt64(&s.czData.callsFailed, 1)
s.channelz.ServerMetrics.CallsFailed.Add(1)
}
func (s *Server) sendResponse(ctx context.Context, t transport.ServerTransport, stream *transport.Stream, msg any, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
if err != nil {
channelz.Error(logger, s.channelzID, "grpc: server failed to encode response: ", err)
channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err)
return err
}
compData, err := compress(data, cp, comp)
if err != nil {
channelz.Error(logger, s.channelzID, "grpc: server failed to compress response: ", err)
channelz.Error(logger, s.channelz, "grpc: server failed to compress response: ", err)
return err
}
hdr, payload := msgHeader(data, compData)
@@ -1342,10 +1335,11 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
if len(shs) != 0 || len(binlogs) != 0 {
payInfo = &payloadInfo{}
}
d, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
if err != nil {
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e)
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
}
return err
}
@@ -1353,6 +1347,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
t.IncrMsgRecv()
}
df := func(v any) error {
defer cancel()
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
}
@@ -1394,7 +1390,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
trInfo.tr.SetError()
}
if e := t.WriteStatus(stream, appStatus); e != nil {
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e)
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
}
if len(binlogs) != 0 {
if h, _ := stream.Header(); h.Len() > 0 {
@@ -1434,7 +1430,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
}
if sts, ok := status.FromError(err); ok {
if e := t.WriteStatus(stream, sts); e != nil {
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e)
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
}
} else {
switch st := err.(type) {
@@ -1762,7 +1758,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
ti.tr.SetError()
}
channelz.Warningf(logger, s.channelzID, "grpc: Server.handleStream failed to write status: %v", err)
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
}
if ti != nil {
ti.tr.Finish()
@@ -1819,7 +1815,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
ti.tr.SetError()
}
channelz.Warningf(logger, s.channelzID, "grpc: Server.handleStream failed to write status: %v", err)
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
}
if ti != nil {
ti.tr.Finish()
@@ -1891,8 +1887,7 @@ func (s *Server) stop(graceful bool) {
s.quit.Fire()
defer s.done.Fire()
s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) })
s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelz.ID) })
s.mu.Lock()
s.closeListenersLocked()
// Wait for serving threads to be ready to exit. Only then can we be sure no
@@ -2117,7 +2112,7 @@ func ClientSupportedCompressors(ctx context.Context) ([]string, error) {
return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx)
}
return strings.Split(stream.ClientAdvertisedCompressors(), ","), nil
return stream.ClientAdvertisedCompressors(), nil
}
// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
@@ -2147,17 +2142,9 @@ func Method(ctx context.Context) (string, bool) {
return s.Method(), true
}
type channelzServer struct {
s *Server
}
func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric {
return c.s.channelzMetric()
}
// validateSendCompressor returns an error when given compressor name cannot be
// handled by the server or the client based on the advertised compressors.
func validateSendCompressor(name, clientCompressors string) error {
func validateSendCompressor(name string, clientCompressors []string) error {
if name == encoding.Identity {
return nil
}
@@ -2166,7 +2153,7 @@ func validateSendCompressor(name, clientCompressors string) error {
return fmt.Errorf("compressor not registered %q", name)
}
for _, c := range strings.Split(clientCompressors, ",") {
for _, c := range clientCompressors {
if c == name {
return nil // found match
}

View File

@@ -25,8 +25,11 @@ import (
"reflect"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/pickfirst"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/gracefulswitch"
internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
"google.golang.org/grpc/serviceconfig"
)
@@ -41,11 +44,6 @@ const maxInt = int(^uint(0) >> 1)
// https://github.com/grpc/grpc/blob/master/doc/service_config.md
type MethodConfig = internalserviceconfig.MethodConfig
type lbConfig struct {
name string
cfg serviceconfig.LoadBalancingConfig
}
// ServiceConfig is provided by the service provider and contains parameters for how
// clients that connect to the service should behave.
//
@@ -55,14 +53,9 @@ type lbConfig struct {
type ServiceConfig struct {
serviceconfig.Config
// LB is the load balancer the service providers recommends. This is
// deprecated; lbConfigs is preferred. If lbConfig and LB are both present,
// lbConfig will be used.
LB *string
// lbConfig is the service config's load balancing configuration. If
// lbConfig and LB are both present, lbConfig will be used.
lbConfig *lbConfig
lbConfig serviceconfig.LoadBalancingConfig
// Methods contains a map for the methods in this service. If there is an
// exact match for a method (i.e. /service/method) in the map, use the
@@ -164,38 +157,55 @@ type jsonMC struct {
// TODO(lyuxuan): delete this struct after cleaning up old service config implementation.
type jsonSC struct {
LoadBalancingPolicy *string
LoadBalancingConfig *internalserviceconfig.BalancerConfig
LoadBalancingConfig *json.RawMessage
MethodConfig *[]jsonMC
RetryThrottling *retryThrottlingPolicy
HealthCheckConfig *healthCheckConfig
}
func init() {
internal.ParseServiceConfig = parseServiceConfig
internal.ParseServiceConfig = func(js string) *serviceconfig.ParseResult {
return parseServiceConfig(js, defaultMaxCallAttempts)
}
}
func parseServiceConfig(js string) *serviceconfig.ParseResult {
func parseServiceConfig(js string, maxAttempts int) *serviceconfig.ParseResult {
if len(js) == 0 {
return &serviceconfig.ParseResult{Err: fmt.Errorf("no JSON service config provided")}
}
var rsc jsonSC
err := json.Unmarshal([]byte(js), &rsc)
if err != nil {
logger.Warningf("grpc: unmarshaling service config %s: %v", js, err)
logger.Warningf("grpc: unmarshalling service config %s: %v", js, err)
return &serviceconfig.ParseResult{Err: err}
}
sc := ServiceConfig{
LB: rsc.LoadBalancingPolicy,
Methods: make(map[string]MethodConfig),
retryThrottling: rsc.RetryThrottling,
healthCheckConfig: rsc.HealthCheckConfig,
rawJSONString: js,
}
if c := rsc.LoadBalancingConfig; c != nil {
sc.lbConfig = &lbConfig{
name: c.Name,
cfg: c.Config,
c := rsc.LoadBalancingConfig
if c == nil {
name := pickfirst.Name
if rsc.LoadBalancingPolicy != nil {
name = *rsc.LoadBalancingPolicy
}
if balancer.Get(name) == nil {
name = pickfirst.Name
}
cfg := []map[string]any{{name: struct{}{}}}
strCfg, err := json.Marshal(cfg)
if err != nil {
return &serviceconfig.ParseResult{Err: fmt.Errorf("unexpected error marshaling simple LB config: %w", err)}
}
r := json.RawMessage(strCfg)
c = &r
}
cfg, err := gracefulswitch.ParseConfig(*c)
if err != nil {
return &serviceconfig.ParseResult{Err: err}
}
sc.lbConfig = cfg
if rsc.MethodConfig == nil {
return &serviceconfig.ParseResult{Config: &sc}
@@ -211,8 +221,8 @@ func parseServiceConfig(js string) *serviceconfig.ParseResult {
WaitForReady: m.WaitForReady,
Timeout: (*time.Duration)(m.Timeout),
}
if mc.RetryPolicy, err = convertRetryPolicy(m.RetryPolicy); err != nil {
logger.Warningf("grpc: unmarshaling service config %s: %v", js, err)
if mc.RetryPolicy, err = convertRetryPolicy(m.RetryPolicy, maxAttempts); err != nil {
logger.Warningf("grpc: unmarshalling service config %s: %v", js, err)
return &serviceconfig.ParseResult{Err: err}
}
if m.MaxRequestMessageBytes != nil {
@@ -232,13 +242,13 @@ func parseServiceConfig(js string) *serviceconfig.ParseResult {
for i, n := range *m.Name {
path, err := n.generatePath()
if err != nil {
logger.Warningf("grpc: error unmarshaling service config %s due to methodConfig[%d]: %v", js, i, err)
logger.Warningf("grpc: error unmarshalling service config %s due to methodConfig[%d]: %v", js, i, err)
return &serviceconfig.ParseResult{Err: err}
}
if _, ok := paths[path]; ok {
err = errDuplicatedName
logger.Warningf("grpc: error unmarshaling service config %s due to methodConfig[%d]: %v", js, i, err)
logger.Warningf("grpc: error unmarshalling service config %s due to methodConfig[%d]: %v", js, i, err)
return &serviceconfig.ParseResult{Err: err}
}
paths[path] = struct{}{}
@@ -257,7 +267,7 @@ func parseServiceConfig(js string) *serviceconfig.ParseResult {
return &serviceconfig.ParseResult{Config: &sc}
}
func convertRetryPolicy(jrp *jsonRetryPolicy) (p *internalserviceconfig.RetryPolicy, err error) {
func convertRetryPolicy(jrp *jsonRetryPolicy, maxAttempts int) (p *internalserviceconfig.RetryPolicy, err error) {
if jrp == nil {
return nil, nil
}
@@ -271,17 +281,16 @@ func convertRetryPolicy(jrp *jsonRetryPolicy) (p *internalserviceconfig.RetryPol
return nil, nil
}
if jrp.MaxAttempts < maxAttempts {
maxAttempts = jrp.MaxAttempts
}
rp := &internalserviceconfig.RetryPolicy{
MaxAttempts: jrp.MaxAttempts,
MaxAttempts: maxAttempts,
InitialBackoff: time.Duration(jrp.InitialBackoff),
MaxBackoff: time.Duration(jrp.MaxBackoff),
BackoffMultiplier: jrp.BackoffMultiplier,
RetryableStatusCodes: make(map[codes.Code]bool),
}
if rp.MaxAttempts > 5 {
// TODO(retry): Make the max maxAttempts configurable.
rp.MaxAttempts = 5
}
for _, code := range jrp.RetryableStatusCodes {
rp.RetryableStatusCodes[code] = true
}

View File

@@ -73,9 +73,12 @@ func (*PickerUpdated) isRPCStats() {}
type InPayload struct {
// Client is true if this InPayload is from client side.
Client bool
// Payload is the payload with original type.
// Payload is the payload with original type. This may be modified after
// the call to HandleRPC which provides the InPayload returns and must be
// copied if needed later.
Payload any
// Data is the serialized message payload.
// Deprecated: Data will be removed in the next release.
Data []byte
// Length is the size of the uncompressed payload data. Does not include any
@@ -143,9 +146,12 @@ func (s *InTrailer) isRPCStats() {}
type OutPayload struct {
// Client is true if this OutPayload is from client side.
Client bool
// Payload is the payload with original type.
// Payload is the payload with original type. This may be modified after
// the call to HandleRPC which provides the OutPayload returns and must be
// copied if needed later.
Payload any
// Data is the serialized message payload.
// Deprecated: Data will be removed in the next release.
Data []byte
// Length is the size of the uncompressed payload data. Does not include any
// framing (gRPC or HTTP/2).

View File

@@ -23,6 +23,7 @@ import (
"errors"
"io"
"math"
"math/rand"
"strconv"
"sync"
"time"
@@ -34,7 +35,6 @@ import (
"google.golang.org/grpc/internal/balancerload"
"google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata"
iresolver "google.golang.org/grpc/internal/resolver"
@@ -516,6 +516,7 @@ func (a *csAttempt) newStream() error {
return toRPCErr(nse.Err)
}
a.s = s
a.ctx = s.Context()
a.p = &parser{r: s, recvBufferPool: a.cs.cc.dopts.recvBufferPool}
return nil
}
@@ -655,13 +656,13 @@ func (a *csAttempt) shouldRetry(err error) (bool, error) {
if len(sps) == 1 {
var e error
if pushback, e = strconv.Atoi(sps[0]); e != nil || pushback < 0 {
channelz.Infof(logger, cs.cc.channelzID, "Server retry pushback specified to abort (%q).", sps[0])
channelz.Infof(logger, cs.cc.channelz, "Server retry pushback specified to abort (%q).", sps[0])
cs.retryThrottler.throttle() // This counts as a failure for throttling.
return false, err
}
hasPushback = true
} else if len(sps) > 1 {
channelz.Warningf(logger, cs.cc.channelzID, "Server retry pushback specified multiple values (%q); not retrying.", sps)
channelz.Warningf(logger, cs.cc.channelz, "Server retry pushback specified multiple values (%q); not retrying.", sps)
cs.retryThrottler.throttle() // This counts as a failure for throttling.
return false, err
}
@@ -698,7 +699,7 @@ func (a *csAttempt) shouldRetry(err error) (bool, error) {
if max := float64(rp.MaxBackoff); cur > max {
cur = max
}
dur = time.Duration(grpcrand.Int63n(int64(cur)))
dur = time.Duration(rand.Int63n(int64(cur)))
cs.numRetriesSincePushback++
}

152
vendor/google.golang.org/grpc/stream_interfaces.go generated vendored Normal file
View File

@@ -0,0 +1,152 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
// ServerStreamingClient represents the client side of a server-streaming (one
// request, many responses) RPC. It is generic over the type of the response
// message. It is used in generated code.
type ServerStreamingClient[Res any] interface {
Recv() (*Res, error)
ClientStream
}
// ServerStreamingServer represents the server side of a server-streaming (one
// request, many responses) RPC. It is generic over the type of the response
// message. It is used in generated code.
type ServerStreamingServer[Res any] interface {
Send(*Res) error
ServerStream
}
// ClientStreamingClient represents the client side of a client-streaming (many
// requests, one response) RPC. It is generic over both the type of the request
// message stream and the type of the unary response message. It is used in
// generated code.
type ClientStreamingClient[Req any, Res any] interface {
Send(*Req) error
CloseAndRecv() (*Res, error)
ClientStream
}
// ClientStreamingServer represents the server side of a client-streaming (many
// requests, one response) RPC. It is generic over both the type of the request
// message stream and the type of the unary response message. It is used in
// generated code.
type ClientStreamingServer[Req any, Res any] interface {
Recv() (*Req, error)
SendAndClose(*Res) error
ServerStream
}
// BidiStreamingClient represents the client side of a bidirectional-streaming
// (many requests, many responses) RPC. It is generic over both the type of the
// request message stream and the type of the response message stream. It is
// used in generated code.
type BidiStreamingClient[Req any, Res any] interface {
Send(*Req) error
Recv() (*Res, error)
ClientStream
}
// BidiStreamingServer represents the server side of a bidirectional-streaming
// (many requests, many responses) RPC. It is generic over both the type of the
// request message stream and the type of the response message stream. It is
// used in generated code.
type BidiStreamingServer[Req any, Res any] interface {
Recv() (*Req, error)
Send(*Res) error
ServerStream
}
// GenericClientStream implements the ServerStreamingClient, ClientStreamingClient,
// and BidiStreamingClient interfaces. It is used in generated code.
type GenericClientStream[Req any, Res any] struct {
ClientStream
}
var _ ServerStreamingClient[string] = (*GenericClientStream[int, string])(nil)
var _ ClientStreamingClient[int, string] = (*GenericClientStream[int, string])(nil)
var _ BidiStreamingClient[int, string] = (*GenericClientStream[int, string])(nil)
// Send pushes one message into the stream of requests to be consumed by the
// server. The type of message which can be sent is determined by the Req type
// parameter of the GenericClientStream receiver.
func (x *GenericClientStream[Req, Res]) Send(m *Req) error {
return x.ClientStream.SendMsg(m)
}
// Recv reads one message from the stream of responses generated by the server.
// The type of the message returned is determined by the Res type parameter
// of the GenericClientStream receiver.
func (x *GenericClientStream[Req, Res]) Recv() (*Res, error) {
m := new(Res)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// CloseAndRecv closes the sending side of the stream, then receives the unary
// response from the server. The type of message which it returns is determined
// by the Res type parameter of the GenericClientStream receiver.
func (x *GenericClientStream[Req, Res]) CloseAndRecv() (*Res, error) {
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
m := new(Res)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// GenericServerStream implements the ServerStreamingServer, ClientStreamingServer,
// and BidiStreamingServer interfaces. It is used in generated code.
type GenericServerStream[Req any, Res any] struct {
ServerStream
}
var _ ServerStreamingServer[string] = (*GenericServerStream[int, string])(nil)
var _ ClientStreamingServer[int, string] = (*GenericServerStream[int, string])(nil)
var _ BidiStreamingServer[int, string] = (*GenericServerStream[int, string])(nil)
// Send pushes one message into the stream of responses to be consumed by the
// client. The type of message which can be sent is determined by the Res
// type parameter of the serverStreamServer receiver.
func (x *GenericServerStream[Req, Res]) Send(m *Res) error {
return x.ServerStream.SendMsg(m)
}
// SendAndClose pushes the unary response to the client. The type of message
// which can be sent is determined by the Res type parameter of the
// clientStreamServer receiver.
func (x *GenericServerStream[Req, Res]) SendAndClose(m *Res) error {
return x.ServerStream.SendMsg(m)
}
// Recv reads one message from the stream of requests generated by the client.
// The type of the message returned is determined by the Req type parameter
// of the clientStreamServer receiver.
func (x *GenericServerStream[Req, Res]) Recv() (*Req, error) {
m := new(Req)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}

View File

@@ -19,4 +19,4 @@
package grpc
// Version is the current grpc version.
const Version = "1.62.0"
const Version = "1.65.0"

190
vendor/google.golang.org/grpc/vet.sh generated vendored
View File

@@ -1,190 +0,0 @@
#!/bin/bash
set -ex # Exit on error; debugging enabled.
set -o pipefail # Fail a pipe if any sub-command fails.
# not makes sure the command passed to it does not exit with a return code of 0.
not() {
# This is required instead of the earlier (! $COMMAND) because subshells and
# pipefail don't work the same on Darwin as in Linux.
! "$@"
}
die() {
echo "$@" >&2
exit 1
}
fail_on_output() {
tee /dev/stderr | not read
}
# Check to make sure it's safe to modify the user's git repo.
git status --porcelain | fail_on_output
# Undo any edits made by this script.
cleanup() {
git reset --hard HEAD
}
trap cleanup EXIT
PATH="${HOME}/go/bin:${GOROOT}/bin:${PATH}"
go version
if [[ "$1" = "-install" ]]; then
# Install the pinned versions as defined in module tools.
pushd ./test/tools
go install \
golang.org/x/tools/cmd/goimports \
honnef.co/go/tools/cmd/staticcheck \
github.com/client9/misspell/cmd/misspell
popd
if [[ -z "${VET_SKIP_PROTO}" ]]; then
if [[ "${GITHUB_ACTIONS}" = "true" ]]; then
PROTOBUF_VERSION=25.2 # a.k.a. v4.22.0 in pb.go files.
PROTOC_FILENAME=protoc-${PROTOBUF_VERSION}-linux-x86_64.zip
pushd /home/runner/go
wget https://github.com/google/protobuf/releases/download/v${PROTOBUF_VERSION}/${PROTOC_FILENAME}
unzip ${PROTOC_FILENAME}
bin/protoc --version
popd
elif not which protoc > /dev/null; then
die "Please install protoc into your path"
fi
fi
exit 0
elif [[ "$#" -ne 0 ]]; then
die "Unknown argument(s): $*"
fi
# - Check that generated proto files are up to date.
if [[ -z "${VET_SKIP_PROTO}" ]]; then
make proto && git status --porcelain 2>&1 | fail_on_output || \
(git status; git --no-pager diff; exit 1)
fi
if [[ -n "${VET_ONLY_PROTO}" ]]; then
exit 0
fi
# - Ensure all source files contain a copyright message.
# (Done in two parts because Darwin "git grep" has broken support for compound
# exclusion matches.)
(grep -L "DO NOT EDIT" $(git grep -L "\(Copyright [0-9]\{4,\} gRPC authors\)" -- '*.go') || true) | fail_on_output
# - Make sure all tests in grpc and grpc/test use leakcheck via Teardown.
not grep 'func Test[^(]' *_test.go
not grep 'func Test[^(]' test/*.go
# - Check for typos in test function names
git grep 'func (s) ' -- "*_test.go" | not grep -v 'func (s) Test'
git grep 'func [A-Z]' -- "*_test.go" | not grep -v 'func Test\|Benchmark\|Example'
# - Do not import x/net/context.
not git grep -l 'x/net/context' -- "*.go"
# - Do not import math/rand for real library code. Use internal/grpcrand for
# thread safety.
git grep -l '"math/rand"' -- "*.go" 2>&1 | not grep -v '^examples\|^interop/stress\|grpcrand\|^benchmark\|wrr_test'
# - Do not use "interface{}"; use "any" instead.
git grep -l 'interface{}' -- "*.go" 2>&1 | not grep -v '\.pb\.go\|protoc-gen-go-grpc\|grpc_testing_not_regenerate'
# - Do not call grpclog directly. Use grpclog.Component instead.
git grep -l -e 'grpclog.I' --or -e 'grpclog.W' --or -e 'grpclog.E' --or -e 'grpclog.F' --or -e 'grpclog.V' -- "*.go" | not grep -v '^grpclog/component.go\|^internal/grpctest/tlogger_test.go'
# - Ensure all ptypes proto packages are renamed when importing.
not git grep "\(import \|^\s*\)\"github.com/golang/protobuf/ptypes/" -- "*.go"
# - Ensure all usages of grpc_testing package are renamed when importing.
not git grep "\(import \|^\s*\)\"google.golang.org/grpc/interop/grpc_testing" -- "*.go"
# - Ensure all xds proto imports are renamed to *pb or *grpc.
git grep '"github.com/envoyproxy/go-control-plane/envoy' -- '*.go' ':(exclude)*.pb.go' | not grep -v 'pb "\|grpc "'
misspell -error .
# - gofmt, goimports, go vet, go mod tidy.
# Perform these checks on each module inside gRPC.
for MOD_FILE in $(find . -name 'go.mod'); do
MOD_DIR=$(dirname ${MOD_FILE})
pushd ${MOD_DIR}
go vet -all ./... | fail_on_output
gofmt -s -d -l . 2>&1 | fail_on_output
goimports -l . 2>&1 | not grep -vE "\.pb\.go"
go mod tidy -compat=1.19
git status --porcelain 2>&1 | fail_on_output || \
(git status; git --no-pager diff; exit 1)
popd
done
# - Collection of static analysis checks
SC_OUT="$(mktemp)"
staticcheck -go 1.19 -checks 'all' ./... > "${SC_OUT}" || true
# Error for anything other than checks that need exclusions.
grep -v "(ST1000)" "${SC_OUT}" | grep -v "(SA1019)" | grep -v "(ST1003)" | not grep -v "(ST1019)\|\(other import of\)"
# Exclude underscore checks for generated code.
grep "(ST1003)" "${SC_OUT}" | not grep -v '\(.pb.go:\)\|\(code_string_test.go:\)\|\(grpc_testing_not_regenerate\)'
# Error for duplicate imports not including grpc protos.
grep "(ST1019)\|\(other import of\)" "${SC_OUT}" | not grep -Fv 'XXXXX PleaseIgnoreUnused
channelz/grpc_channelz_v1"
go-control-plane/envoy
grpclb/grpc_lb_v1"
health/grpc_health_v1"
interop/grpc_testing"
orca/v3"
proto/grpc_gcp"
proto/grpc_lookup_v1"
reflection/grpc_reflection_v1"
reflection/grpc_reflection_v1alpha"
XXXXX PleaseIgnoreUnused'
# Error for any package comments not in generated code.
grep "(ST1000)" "${SC_OUT}" | not grep -v "\.pb\.go:"
# Only ignore the following deprecated types/fields/functions and exclude
# generated code.
grep "(SA1019)" "${SC_OUT}" | not grep -Fv 'XXXXX PleaseIgnoreUnused
XXXXX Protobuf related deprecation errors:
"github.com/golang/protobuf
.pb.go:
grpc_testing_not_regenerate
: ptypes.
proto.RegisterType
XXXXX gRPC internal usage deprecation errors:
"google.golang.org/grpc
: grpc.
: v1alpha.
: v1alphareflectionpb.
BalancerAttributes is deprecated:
CredsBundle is deprecated:
Metadata is deprecated: use Attributes instead.
NewSubConn is deprecated:
OverrideServerName is deprecated:
RemoveSubConn is deprecated:
SecurityVersion is deprecated:
Target is deprecated: Use the Target field in the BuildOptions instead.
UpdateAddresses is deprecated:
UpdateSubConnState is deprecated:
balancer.ErrTransientFailure is deprecated:
grpc/reflection/v1alpha/reflection.proto
XXXXX xDS deprecated fields we support
.ExactMatch
.PrefixMatch
.SafeRegexMatch
.SuffixMatch
GetContainsMatch
GetExactMatch
GetMatchSubjectAltNames
GetPrefixMatch
GetSafeRegexMatch
GetSuffixMatch
GetTlsCertificateCertificateProviderInstance
GetValidationContextCertificateProviderInstance
XXXXX PleaseIgnoreUnused'
echo SUCCESS

View File

@@ -102,7 +102,7 @@ type decoder struct {
}
// newError returns an error object with position info.
func (d decoder) newError(pos int, f string, x ...interface{}) error {
func (d decoder) newError(pos int, f string, x ...any) error {
line, column := d.Position(pos)
head := fmt.Sprintf("(line %d:%d): ", line, column)
return errors.New(head+f, x...)
@@ -114,7 +114,7 @@ func (d decoder) unexpectedTokenError(tok json.Token) error {
}
// syntaxError returns a syntax error for given position.
func (d decoder) syntaxError(pos int, f string, x ...interface{}) error {
func (d decoder) syntaxError(pos int, f string, x ...any) error {
line, column := d.Position(pos)
head := fmt.Sprintf("syntax error (line %d:%d): ", line, column)
return errors.New(head+f, x...)

View File

@@ -25,15 +25,17 @@ const defaultIndent = " "
// Format formats the message as a multiline string.
// This function is only intended for human consumption and ignores errors.
// Do not depend on the output being stable. It may change over time across
// different versions of the program.
// Do not depend on the output being stable. Its output will change across
// different builds of your program, even when using the same version of the
// protobuf module.
func Format(m proto.Message) string {
return MarshalOptions{Multiline: true}.Format(m)
}
// Marshal writes the given [proto.Message] in JSON format using default options.
// Do not depend on the output being stable. It may change over time across
// different versions of the program.
// Do not depend on the output being stable. Its output will change across
// different builds of your program, even when using the same version of the
// protobuf module.
func Marshal(m proto.Message) ([]byte, error) {
return MarshalOptions{}.Marshal(m)
}
@@ -110,8 +112,9 @@ type MarshalOptions struct {
// Format formats the message as a string.
// This method is only intended for human consumption and ignores errors.
// Do not depend on the output being stable. It may change over time across
// different versions of the program.
// Do not depend on the output being stable. Its output will change across
// different builds of your program, even when using the same version of the
// protobuf module.
func (o MarshalOptions) Format(m proto.Message) string {
if m == nil || !m.ProtoReflect().IsValid() {
return "<nil>" // invalid syntax, but okay since this is for debugging
@@ -122,8 +125,9 @@ func (o MarshalOptions) Format(m proto.Message) string {
}
// Marshal marshals the given [proto.Message] in the JSON format using options in
// MarshalOptions. Do not depend on the output being stable. It may change over
// time across different versions of the program.
// Do not depend on the output being stable. Its output will change across
// different builds of your program, even when using the same version of the
// protobuf module.
func (o MarshalOptions) Marshal(m proto.Message) ([]byte, error) {
return o.marshal(nil, m)
}

View File

@@ -84,7 +84,7 @@ type decoder struct {
}
// newError returns an error object with position info.
func (d decoder) newError(pos int, f string, x ...interface{}) error {
func (d decoder) newError(pos int, f string, x ...any) error {
line, column := d.Position(pos)
head := fmt.Sprintf("(line %d:%d): ", line, column)
return errors.New(head+f, x...)
@@ -96,7 +96,7 @@ func (d decoder) unexpectedTokenError(tok text.Token) error {
}
// syntaxError returns a syntax error for given position.
func (d decoder) syntaxError(pos int, f string, x ...interface{}) error {
func (d decoder) syntaxError(pos int, f string, x ...any) error {
line, column := d.Position(pos)
head := fmt.Sprintf("syntax error (line %d:%d): ", line, column)
return errors.New(head+f, x...)

View File

@@ -27,15 +27,17 @@ const defaultIndent = " "
// Format formats the message as a multiline string.
// This function is only intended for human consumption and ignores errors.
// Do not depend on the output being stable. It may change over time across
// different versions of the program.
// Do not depend on the output being stable. Its output will change across
// different builds of your program, even when using the same version of the
// protobuf module.
func Format(m proto.Message) string {
return MarshalOptions{Multiline: true}.Format(m)
}
// Marshal writes the given [proto.Message] in textproto format using default
// options. Do not depend on the output being stable. It may change over time
// across different versions of the program.
// options. Do not depend on the output being stable. Its output will change
// across different builds of your program, even when using the same version of
// the protobuf module.
func Marshal(m proto.Message) ([]byte, error) {
return MarshalOptions{}.Marshal(m)
}
@@ -84,8 +86,9 @@ type MarshalOptions struct {
// Format formats the message as a string.
// This method is only intended for human consumption and ignores errors.
// Do not depend on the output being stable. It may change over time across
// different versions of the program.
// Do not depend on the output being stable. Its output will change across
// different builds of your program, even when using the same version of the
// protobuf module.
func (o MarshalOptions) Format(m proto.Message) string {
if m == nil || !m.ProtoReflect().IsValid() {
return "<nil>" // invalid syntax, but okay since this is for debugging
@@ -98,8 +101,9 @@ func (o MarshalOptions) Format(m proto.Message) string {
}
// Marshal writes the given [proto.Message] in textproto format using options in
// MarshalOptions object. Do not depend on the output being stable. It may
// change over time across different versions of the program.
// MarshalOptions object. Do not depend on the output being stable. Its output
// will change across different builds of your program, even when using the
// same version of the protobuf module.
func (o MarshalOptions) Marshal(m proto.Message) ([]byte, error) {
return o.marshal(nil, m)
}

View File

@@ -252,6 +252,7 @@ func formatDescOpt(t protoreflect.Descriptor, isRoot, allowMulti bool, record fu
{rv.MethodByName("Values"), "Values"},
{rv.MethodByName("ReservedNames"), "ReservedNames"},
{rv.MethodByName("ReservedRanges"), "ReservedRanges"},
{rv.MethodByName("IsClosed"), "IsClosed"},
}...)
case protoreflect.EnumValueDescriptor:

View File

@@ -0,0 +1,13 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package editionssupport defines constants for editions that are supported.
package editionssupport
import descriptorpb "google.golang.org/protobuf/types/descriptorpb"
const (
Minimum = descriptorpb.Edition_EDITION_PROTO2
Maximum = descriptorpb.Edition_EDITION_2023
)

View File

@@ -214,7 +214,7 @@ func (d *Decoder) parseNext() (Token, error) {
// newSyntaxError returns an error with line and column information useful for
// syntax errors.
func (d *Decoder) newSyntaxError(pos int, f string, x ...interface{}) error {
func (d *Decoder) newSyntaxError(pos int, f string, x ...any) error {
e := errors.New(f, x...)
line, column := d.Position(pos)
return errors.New("syntax error (line %d:%d): %v", line, column, e)

View File

@@ -32,6 +32,7 @@ var byteType = reflect.TypeOf(byte(0))
func Unmarshal(tag string, goType reflect.Type, evs protoreflect.EnumValueDescriptors) protoreflect.FieldDescriptor {
f := new(filedesc.Field)
f.L0.ParentFile = filedesc.SurrogateProto2
f.L1.EditionFeatures = f.L0.ParentFile.L1.EditionFeatures
for len(tag) > 0 {
i := strings.IndexByte(tag, ',')
if i < 0 {
@@ -107,8 +108,7 @@ func Unmarshal(tag string, goType reflect.Type, evs protoreflect.EnumValueDescri
f.L1.StringName.InitJSON(jsonName)
}
case s == "packed":
f.L1.HasPacked = true
f.L1.IsPacked = true
f.L1.EditionFeatures.IsPacked = true
case strings.HasPrefix(s, "weak="):
f.L1.IsWeak = true
f.L1.Message = filedesc.PlaceholderMessage(protoreflect.FullName(s[len("weak="):]))

View File

@@ -601,7 +601,7 @@ func (d *Decoder) consumeToken(kind Kind, size int, attrs uint8) Token {
// newSyntaxError returns a syntax error with line and column information for
// current position.
func (d *Decoder) newSyntaxError(f string, x ...interface{}) error {
func (d *Decoder) newSyntaxError(f string, x ...any) error {
e := errors.New(f, x...)
line, column := d.Position(len(d.orig) - len(d.in))
return errors.New("syntax error (line %d:%d): %v", line, column, e)

View File

@@ -17,7 +17,7 @@ var Error = errors.New("protobuf error")
// New formats a string according to the format specifier and arguments and
// returns an error that has a "proto" prefix.
func New(f string, x ...interface{}) error {
func New(f string, x ...any) error {
return &prefixError{s: format(f, x...)}
}
@@ -43,7 +43,7 @@ func (e *prefixError) Unwrap() error {
// Wrap returns an error that has a "proto" prefix, the formatted string described
// by the format specifier and arguments, and a suffix of err. The error wraps err.
func Wrap(err error, f string, x ...interface{}) error {
func Wrap(err error, f string, x ...any) error {
return &wrapError{
s: format(f, x...),
err: err,
@@ -67,7 +67,7 @@ func (e *wrapError) Is(target error) bool {
return target == Error
}
func format(f string, x ...interface{}) string {
func format(f string, x ...any) string {
// avoid "proto: " prefix when chaining
for i := 0; i < len(x); i++ {
switch e := x[i].(type) {
@@ -87,3 +87,18 @@ func InvalidUTF8(name string) error {
func RequiredNotSet(name string) error {
return New("required field %v not set", name)
}
type SizeMismatchError struct {
Calculated, Measured int
}
func (e *SizeMismatchError) Error() string {
return fmt.Sprintf("size mismatch (see https://github.com/golang/protobuf/issues/1609): calculated=%d, measured=%d", e.Calculated, e.Measured)
}
func MismatchedSizeCalculation(calculated, measured int) error {
return &SizeMismatchError{
Calculated: calculated,
Measured: measured,
}
}

View File

@@ -7,6 +7,7 @@ package filedesc
import (
"bytes"
"fmt"
"strings"
"sync"
"sync/atomic"
@@ -108,9 +109,12 @@ func (fd *File) ParentFile() protoreflect.FileDescriptor { return fd }
func (fd *File) Parent() protoreflect.Descriptor { return nil }
func (fd *File) Index() int { return 0 }
func (fd *File) Syntax() protoreflect.Syntax { return fd.L1.Syntax }
func (fd *File) Name() protoreflect.Name { return fd.L1.Package.Name() }
func (fd *File) FullName() protoreflect.FullName { return fd.L1.Package }
func (fd *File) IsPlaceholder() bool { return false }
// Not exported and just used to reconstruct the original FileDescriptor proto
func (fd *File) Edition() int32 { return int32(fd.L1.Edition) }
func (fd *File) Name() protoreflect.Name { return fd.L1.Package.Name() }
func (fd *File) FullName() protoreflect.FullName { return fd.L1.Package }
func (fd *File) IsPlaceholder() bool { return false }
func (fd *File) Options() protoreflect.ProtoMessage {
if f := fd.lazyInit().Options; f != nil {
return f()
@@ -202,6 +206,9 @@ func (ed *Enum) lazyInit() *EnumL2 {
ed.L0.ParentFile.lazyInit() // implicitly initializes L2
return ed.L2
}
func (ed *Enum) IsClosed() bool {
return !ed.L1.EditionFeatures.IsOpenEnum
}
func (ed *EnumValue) Options() protoreflect.ProtoMessage {
if f := ed.L1.Options; f != nil {
@@ -251,10 +258,6 @@ type (
StringName stringName
IsProto3Optional bool // promoted from google.protobuf.FieldDescriptorProto
IsWeak bool // promoted from google.protobuf.FieldOptions
HasPacked bool // promoted from google.protobuf.FieldOptions
IsPacked bool // promoted from google.protobuf.FieldOptions
HasEnforceUTF8 bool // promoted from google.protobuf.FieldOptions
EnforceUTF8 bool // promoted from google.protobuf.FieldOptions
Default defaultValue
ContainingOneof protoreflect.OneofDescriptor // must be consistent with Message.Oneofs.Fields
Enum protoreflect.EnumDescriptor
@@ -331,8 +334,7 @@ func (fd *Field) HasPresence() bool {
if fd.L1.Cardinality == protoreflect.Repeated {
return false
}
explicitFieldPresence := fd.Syntax() == protoreflect.Editions && fd.L1.EditionFeatures.IsFieldPresence
return fd.Syntax() == protoreflect.Proto2 || explicitFieldPresence || fd.L1.Message != nil || fd.L1.ContainingOneof != nil
return fd.IsExtension() || fd.L1.EditionFeatures.IsFieldPresence || fd.L1.Message != nil || fd.L1.ContainingOneof != nil
}
func (fd *Field) HasOptionalKeyword() bool {
return (fd.L0.ParentFile.L1.Syntax == protoreflect.Proto2 && fd.L1.Cardinality == protoreflect.Optional && fd.L1.ContainingOneof == nil) || fd.L1.IsProto3Optional
@@ -345,14 +347,7 @@ func (fd *Field) IsPacked() bool {
case protoreflect.StringKind, protoreflect.BytesKind, protoreflect.MessageKind, protoreflect.GroupKind:
return false
}
if fd.L0.ParentFile.L1.Syntax == protoreflect.Editions {
return fd.L1.EditionFeatures.IsPacked
}
if fd.L0.ParentFile.L1.Syntax == protoreflect.Proto3 {
// proto3 repeated fields are packed by default.
return !fd.L1.HasPacked || fd.L1.IsPacked
}
return fd.L1.IsPacked
return fd.L1.EditionFeatures.IsPacked
}
func (fd *Field) IsExtension() bool { return false }
func (fd *Field) IsWeak() bool { return fd.L1.IsWeak }
@@ -388,6 +383,10 @@ func (fd *Field) Message() protoreflect.MessageDescriptor {
}
return fd.L1.Message
}
func (fd *Field) IsMapEntry() bool {
parent, ok := fd.L0.Parent.(protoreflect.MessageDescriptor)
return ok && parent.IsMapEntry()
}
func (fd *Field) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, fd) }
func (fd *Field) ProtoType(protoreflect.FieldDescriptor) {}
@@ -399,13 +398,7 @@ func (fd *Field) ProtoType(protoreflect.FieldDescriptor) {}
// WARNING: This method is exempt from the compatibility promise and may be
// removed in the future without warning.
func (fd *Field) EnforceUTF8() bool {
if fd.L0.ParentFile.L1.Syntax == protoreflect.Editions {
return fd.L1.EditionFeatures.IsUTF8Validated
}
if fd.L1.HasEnforceUTF8 {
return fd.L1.EnforceUTF8
}
return fd.L0.ParentFile.L1.Syntax == protoreflect.Proto3
return fd.L1.EditionFeatures.IsUTF8Validated
}
func (od *Oneof) IsSynthetic() bool {
@@ -438,7 +431,6 @@ type (
Options func() protoreflect.ProtoMessage
StringName stringName
IsProto3Optional bool // promoted from google.protobuf.FieldDescriptorProto
IsPacked bool // promoted from google.protobuf.FieldOptions
Default defaultValue
Enum protoreflect.EnumDescriptor
Message protoreflect.MessageDescriptor
@@ -461,7 +453,16 @@ func (xd *Extension) HasPresence() bool { return xd.L1.Cardi
func (xd *Extension) HasOptionalKeyword() bool {
return (xd.L0.ParentFile.L1.Syntax == protoreflect.Proto2 && xd.L1.Cardinality == protoreflect.Optional) || xd.lazyInit().IsProto3Optional
}
func (xd *Extension) IsPacked() bool { return xd.lazyInit().IsPacked }
func (xd *Extension) IsPacked() bool {
if xd.L1.Cardinality != protoreflect.Repeated {
return false
}
switch xd.L1.Kind {
case protoreflect.StringKind, protoreflect.BytesKind, protoreflect.MessageKind, protoreflect.GroupKind:
return false
}
return xd.L1.EditionFeatures.IsPacked
}
func (xd *Extension) IsExtension() bool { return true }
func (xd *Extension) IsWeak() bool { return false }
func (xd *Extension) IsList() bool { return xd.Cardinality() == protoreflect.Repeated }
@@ -542,8 +543,9 @@ func (md *Method) ProtoInternal(pragma.DoNotImplement) {}
// Surrogate files are can be used to create standalone descriptors
// where the syntax is only information derived from the parent file.
var (
SurrogateProto2 = &File{L1: FileL1{Syntax: protoreflect.Proto2}, L2: &FileL2{}}
SurrogateProto3 = &File{L1: FileL1{Syntax: protoreflect.Proto3}, L2: &FileL2{}}
SurrogateProto2 = &File{L1: FileL1{Syntax: protoreflect.Proto2}, L2: &FileL2{}}
SurrogateProto3 = &File{L1: FileL1{Syntax: protoreflect.Proto3}, L2: &FileL2{}}
SurrogateEdition2023 = &File{L1: FileL1{Syntax: protoreflect.Editions, Edition: Edition2023}, L2: &FileL2{}}
)
type (
@@ -585,6 +587,34 @@ func (s *stringName) InitJSON(name string) {
s.nameJSON = name
}
// Returns true if this field is structured like the synthetic field of a proto2
// group. This allows us to expand our treatment of delimited fields without
// breaking proto2 files that have been upgraded to editions.
func isGroupLike(fd protoreflect.FieldDescriptor) bool {
// Groups are always group types.
if fd.Kind() != protoreflect.GroupKind {
return false
}
// Group fields are always the lowercase type name.
if strings.ToLower(string(fd.Message().Name())) != string(fd.Name()) {
return false
}
// Groups could only be defined in the same file they're used.
if fd.Message().ParentFile() != fd.ParentFile() {
return false
}
// Group messages are always defined in the same scope as the field. File
// level extensions will compare NULL == NULL here, which is why the file
// comparison above is necessary to ensure both come from the same file.
if fd.IsExtension() {
return fd.Parent() == fd.Message().Parent()
}
return fd.ContainingMessage() == fd.Message().Parent()
}
func (s *stringName) lazyInit(fd protoreflect.FieldDescriptor) *stringName {
s.once.Do(func() {
if fd.IsExtension() {
@@ -605,7 +635,7 @@ func (s *stringName) lazyInit(fd protoreflect.FieldDescriptor) *stringName {
// Format the text name.
s.nameText = string(fd.Name())
if fd.Kind() == protoreflect.GroupKind {
if isGroupLike(fd) {
s.nameText = string(fd.Message().Name())
}
}

View File

@@ -113,8 +113,10 @@ func (fd *File) unmarshalSeed(b []byte) {
switch string(v) {
case "proto2":
fd.L1.Syntax = protoreflect.Proto2
fd.L1.Edition = EditionProto2
case "proto3":
fd.L1.Syntax = protoreflect.Proto3
fd.L1.Edition = EditionProto3
case "editions":
fd.L1.Syntax = protoreflect.Editions
default:
@@ -177,11 +179,10 @@ func (fd *File) unmarshalSeed(b []byte) {
// If syntax is missing, it is assumed to be proto2.
if fd.L1.Syntax == 0 {
fd.L1.Syntax = protoreflect.Proto2
fd.L1.Edition = EditionProto2
}
if fd.L1.Syntax == protoreflect.Editions {
fd.L1.EditionFeatures = getFeaturesFor(fd.L1.Edition)
}
fd.L1.EditionFeatures = getFeaturesFor(fd.L1.Edition)
// Parse editions features from options if any
if options != nil {
@@ -267,6 +268,7 @@ func (ed *Enum) unmarshalSeed(b []byte, sb *strs.Builder, pf *File, pd protorefl
ed.L0.ParentFile = pf
ed.L0.Parent = pd
ed.L0.Index = i
ed.L1.EditionFeatures = featuresFromParentDesc(ed.Parent())
var numValues int
for b := b; len(b) > 0; {
@@ -443,6 +445,7 @@ func (xd *Extension) unmarshalSeed(b []byte, sb *strs.Builder, pf *File, pd prot
xd.L0.ParentFile = pf
xd.L0.Parent = pd
xd.L0.Index = i
xd.L1.EditionFeatures = featuresFromParentDesc(pd)
for len(b) > 0 {
num, typ, n := protowire.ConsumeTag(b)
@@ -467,6 +470,38 @@ func (xd *Extension) unmarshalSeed(b []byte, sb *strs.Builder, pf *File, pd prot
xd.L0.FullName = appendFullName(sb, pd.FullName(), v)
case genid.FieldDescriptorProto_Extendee_field_number:
xd.L1.Extendee = PlaceholderMessage(makeFullName(sb, v))
case genid.FieldDescriptorProto_Options_field_number:
xd.unmarshalOptions(v)
}
default:
m := protowire.ConsumeFieldValue(num, typ, b)
b = b[m:]
}
}
if xd.L1.Kind == protoreflect.MessageKind && xd.L1.EditionFeatures.IsDelimitedEncoded {
xd.L1.Kind = protoreflect.GroupKind
}
}
func (xd *Extension) unmarshalOptions(b []byte) {
for len(b) > 0 {
num, typ, n := protowire.ConsumeTag(b)
b = b[n:]
switch typ {
case protowire.VarintType:
v, m := protowire.ConsumeVarint(b)
b = b[m:]
switch num {
case genid.FieldOptions_Packed_field_number:
xd.L1.EditionFeatures.IsPacked = protowire.DecodeBool(v)
}
case protowire.BytesType:
v, m := protowire.ConsumeBytes(b)
b = b[m:]
switch num {
case genid.FieldOptions_Features_field_number:
xd.L1.EditionFeatures = unmarshalFeatureSet(v, xd.L1.EditionFeatures)
}
default:
m := protowire.ConsumeFieldValue(num, typ, b)
@@ -499,7 +534,7 @@ func (sd *Service) unmarshalSeed(b []byte, sb *strs.Builder, pf *File, pd protor
}
var nameBuilderPool = sync.Pool{
New: func() interface{} { return new(strs.Builder) },
New: func() any { return new(strs.Builder) },
}
func getBuilder() *strs.Builder {

View File

@@ -45,6 +45,11 @@ func (file *File) resolveMessages() {
case protoreflect.MessageKind, protoreflect.GroupKind:
fd.L1.Message = file.resolveMessageDependency(fd.L1.Message, listFieldDeps, depIdx)
depIdx++
if fd.L1.Kind == protoreflect.GroupKind && (fd.IsMap() || fd.IsMapEntry()) {
// A map field might inherit delimited encoding from a file-wide default feature.
// But maps never actually use delimited encoding. (At least for now...)
fd.L1.Kind = protoreflect.MessageKind
}
}
// Default is resolved here since it depends on Enum being resolved.
@@ -466,10 +471,10 @@ func (fd *Field) unmarshalFull(b []byte, sb *strs.Builder, pf *File, pd protoref
b = b[m:]
}
}
if fd.Syntax() == protoreflect.Editions && fd.L1.Kind == protoreflect.MessageKind && fd.L1.EditionFeatures.IsDelimitedEncoded {
if fd.L1.Kind == protoreflect.MessageKind && fd.L1.EditionFeatures.IsDelimitedEncoded {
fd.L1.Kind = protoreflect.GroupKind
}
if fd.Syntax() == protoreflect.Editions && fd.L1.EditionFeatures.IsLegacyRequired {
if fd.L1.EditionFeatures.IsLegacyRequired {
fd.L1.Cardinality = protoreflect.Required
}
if rawTypeName != nil {
@@ -496,13 +501,11 @@ func (fd *Field) unmarshalOptions(b []byte) {
b = b[m:]
switch num {
case genid.FieldOptions_Packed_field_number:
fd.L1.HasPacked = true
fd.L1.IsPacked = protowire.DecodeBool(v)
fd.L1.EditionFeatures.IsPacked = protowire.DecodeBool(v)
case genid.FieldOptions_Weak_field_number:
fd.L1.IsWeak = protowire.DecodeBool(v)
case FieldOptions_EnforceUTF8:
fd.L1.HasEnforceUTF8 = true
fd.L1.EnforceUTF8 = protowire.DecodeBool(v)
fd.L1.EditionFeatures.IsUTF8Validated = protowire.DecodeBool(v)
}
case protowire.BytesType:
v, m := protowire.ConsumeBytes(b)
@@ -548,7 +551,6 @@ func (od *Oneof) unmarshalFull(b []byte, sb *strs.Builder, pf *File, pd protoref
func (xd *Extension) unmarshalFull(b []byte, sb *strs.Builder) {
var rawTypeName []byte
var rawOptions []byte
xd.L1.EditionFeatures = featuresFromParentDesc(xd.L1.Extendee)
xd.L2 = new(ExtensionL2)
for len(b) > 0 {
num, typ, n := protowire.ConsumeTag(b)
@@ -572,7 +574,6 @@ func (xd *Extension) unmarshalFull(b []byte, sb *strs.Builder) {
case genid.FieldDescriptorProto_TypeName_field_number:
rawTypeName = v
case genid.FieldDescriptorProto_Options_field_number:
xd.unmarshalOptions(v)
rawOptions = appendOptions(rawOptions, v)
}
default:
@@ -580,12 +581,6 @@ func (xd *Extension) unmarshalFull(b []byte, sb *strs.Builder) {
b = b[m:]
}
}
if xd.Syntax() == protoreflect.Editions && xd.L1.Kind == protoreflect.MessageKind && xd.L1.EditionFeatures.IsDelimitedEncoded {
xd.L1.Kind = protoreflect.GroupKind
}
if xd.Syntax() == protoreflect.Editions && xd.L1.EditionFeatures.IsLegacyRequired {
xd.L1.Cardinality = protoreflect.Required
}
if rawTypeName != nil {
name := makeFullName(sb, rawTypeName)
switch xd.L1.Kind {
@@ -598,32 +593,6 @@ func (xd *Extension) unmarshalFull(b []byte, sb *strs.Builder) {
xd.L2.Options = xd.L0.ParentFile.builder.optionsUnmarshaler(&descopts.Field, rawOptions)
}
func (xd *Extension) unmarshalOptions(b []byte) {
for len(b) > 0 {
num, typ, n := protowire.ConsumeTag(b)
b = b[n:]
switch typ {
case protowire.VarintType:
v, m := protowire.ConsumeVarint(b)
b = b[m:]
switch num {
case genid.FieldOptions_Packed_field_number:
xd.L2.IsPacked = protowire.DecodeBool(v)
}
case protowire.BytesType:
v, m := protowire.ConsumeBytes(b)
b = b[m:]
switch num {
case genid.FieldOptions_Features_field_number:
xd.L1.EditionFeatures = unmarshalFeatureSet(v, xd.L1.EditionFeatures)
}
default:
m := protowire.ConsumeFieldValue(num, typ, b)
b = b[m:]
}
}
}
func (sd *Service) unmarshalFull(b []byte, sb *strs.Builder) {
var rawMethods [][]byte
var rawOptions []byte

View File

@@ -8,6 +8,7 @@ package filedesc
import (
"fmt"
"strings"
"sync"
"google.golang.org/protobuf/internal/descfmt"
@@ -198,6 +199,16 @@ func (p *Fields) lazyInit() *Fields {
if _, ok := p.byText[d.TextName()]; !ok {
p.byText[d.TextName()] = d
}
if isGroupLike(d) {
lowerJSONName := strings.ToLower(d.JSONName())
if _, ok := p.byJSON[lowerJSONName]; !ok {
p.byJSON[lowerJSONName] = d
}
lowerTextName := strings.ToLower(d.TextName())
if _, ok := p.byText[lowerTextName]; !ok {
p.byText[lowerTextName] = d
}
}
if _, ok := p.byNum[d.Number()]; !ok {
p.byNum[d.Number()] = d
}

View File

@@ -14,9 +14,13 @@ import (
)
var defaultsCache = make(map[Edition]EditionFeatures)
var defaultsKeys = []Edition{}
func init() {
unmarshalEditionDefaults(editiondefaults.Defaults)
SurrogateProto2.L1.EditionFeatures = getFeaturesFor(EditionProto2)
SurrogateProto3.L1.EditionFeatures = getFeaturesFor(EditionProto3)
SurrogateEdition2023.L1.EditionFeatures = getFeaturesFor(Edition2023)
}
func unmarshalGoFeature(b []byte, parent EditionFeatures) EditionFeatures {
@@ -104,12 +108,15 @@ func unmarshalEditionDefault(b []byte) {
v, m := protowire.ConsumeBytes(b)
b = b[m:]
switch num {
case genid.FeatureSetDefaults_FeatureSetEditionDefault_Features_field_number:
case genid.FeatureSetDefaults_FeatureSetEditionDefault_FixedFeatures_field_number:
fs = unmarshalFeatureSet(v, fs)
case genid.FeatureSetDefaults_FeatureSetEditionDefault_OverridableFeatures_field_number:
fs = unmarshalFeatureSet(v, fs)
}
}
}
defaultsCache[ed] = fs
defaultsKeys = append(defaultsKeys, ed)
}
func unmarshalEditionDefaults(b []byte) {
@@ -135,8 +142,15 @@ func unmarshalEditionDefaults(b []byte) {
}
func getFeaturesFor(ed Edition) EditionFeatures {
if def, ok := defaultsCache[ed]; ok {
return def
match := EditionUnknown
for _, key := range defaultsKeys {
if key > ed {
break
}
match = key
}
panic(fmt.Sprintf("unsupported edition: %v", ed))
if match == EditionUnknown {
panic(fmt.Sprintf("unsupported edition: %v", ed))
}
return defaultsCache[match]
}

View File

@@ -63,6 +63,7 @@ func (e PlaceholderEnum) Options() protoreflect.ProtoMessage { return des
func (e PlaceholderEnum) Values() protoreflect.EnumValueDescriptors { return emptyEnumValues }
func (e PlaceholderEnum) ReservedNames() protoreflect.Names { return emptyNames }
func (e PlaceholderEnum) ReservedRanges() protoreflect.EnumRanges { return emptyEnumRanges }
func (e PlaceholderEnum) IsClosed() bool { return false }
func (e PlaceholderEnum) ProtoType(protoreflect.EnumDescriptor) { return }
func (e PlaceholderEnum) ProtoInternal(pragma.DoNotImplement) { return }

View File

@@ -68,7 +68,7 @@ type Builder struct {
// and for input and output messages referenced by service methods.
// Dependencies must come after declarations, but the ordering of
// dependencies themselves is unspecified.
GoTypes []interface{}
GoTypes []any
// DependencyIndexes is an ordered list of indexes into GoTypes for the
// dependencies of messages, extensions, or services.
@@ -268,7 +268,7 @@ func (x depIdxs) Get(i, j int32) int32 {
type (
resolverByIndex struct {
goTypes []interface{}
goTypes []any
depIdxs depIdxs
fileRegistry
}

View File

@@ -21,6 +21,7 @@ const (
// Enum values for google.protobuf.Edition.
const (
Edition_EDITION_UNKNOWN_enum_value = 0
Edition_EDITION_LEGACY_enum_value = 900
Edition_EDITION_PROTO2_enum_value = 998
Edition_EDITION_PROTO3_enum_value = 999
Edition_EDITION_2023_enum_value = 1000
@@ -653,6 +654,7 @@ const (
FieldOptions_Targets_field_name protoreflect.Name = "targets"
FieldOptions_EditionDefaults_field_name protoreflect.Name = "edition_defaults"
FieldOptions_Features_field_name protoreflect.Name = "features"
FieldOptions_FeatureSupport_field_name protoreflect.Name = "feature_support"
FieldOptions_UninterpretedOption_field_name protoreflect.Name = "uninterpreted_option"
FieldOptions_Ctype_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.ctype"
@@ -667,6 +669,7 @@ const (
FieldOptions_Targets_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.targets"
FieldOptions_EditionDefaults_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.edition_defaults"
FieldOptions_Features_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.features"
FieldOptions_FeatureSupport_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.feature_support"
FieldOptions_UninterpretedOption_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.uninterpreted_option"
)
@@ -684,6 +687,7 @@ const (
FieldOptions_Targets_field_number protoreflect.FieldNumber = 19
FieldOptions_EditionDefaults_field_number protoreflect.FieldNumber = 20
FieldOptions_Features_field_number protoreflect.FieldNumber = 21
FieldOptions_FeatureSupport_field_number protoreflect.FieldNumber = 22
FieldOptions_UninterpretedOption_field_number protoreflect.FieldNumber = 999
)
@@ -767,6 +771,33 @@ const (
FieldOptions_EditionDefault_Value_field_number protoreflect.FieldNumber = 2
)
// Names for google.protobuf.FieldOptions.FeatureSupport.
const (
FieldOptions_FeatureSupport_message_name protoreflect.Name = "FeatureSupport"
FieldOptions_FeatureSupport_message_fullname protoreflect.FullName = "google.protobuf.FieldOptions.FeatureSupport"
)
// Field names for google.protobuf.FieldOptions.FeatureSupport.
const (
FieldOptions_FeatureSupport_EditionIntroduced_field_name protoreflect.Name = "edition_introduced"
FieldOptions_FeatureSupport_EditionDeprecated_field_name protoreflect.Name = "edition_deprecated"
FieldOptions_FeatureSupport_DeprecationWarning_field_name protoreflect.Name = "deprecation_warning"
FieldOptions_FeatureSupport_EditionRemoved_field_name protoreflect.Name = "edition_removed"
FieldOptions_FeatureSupport_EditionIntroduced_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.FeatureSupport.edition_introduced"
FieldOptions_FeatureSupport_EditionDeprecated_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.FeatureSupport.edition_deprecated"
FieldOptions_FeatureSupport_DeprecationWarning_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.FeatureSupport.deprecation_warning"
FieldOptions_FeatureSupport_EditionRemoved_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.FeatureSupport.edition_removed"
)
// Field numbers for google.protobuf.FieldOptions.FeatureSupport.
const (
FieldOptions_FeatureSupport_EditionIntroduced_field_number protoreflect.FieldNumber = 1
FieldOptions_FeatureSupport_EditionDeprecated_field_number protoreflect.FieldNumber = 2
FieldOptions_FeatureSupport_DeprecationWarning_field_number protoreflect.FieldNumber = 3
FieldOptions_FeatureSupport_EditionRemoved_field_number protoreflect.FieldNumber = 4
)
// Names for google.protobuf.OneofOptions.
const (
OneofOptions_message_name protoreflect.Name = "OneofOptions"
@@ -829,11 +860,13 @@ const (
EnumValueOptions_Deprecated_field_name protoreflect.Name = "deprecated"
EnumValueOptions_Features_field_name protoreflect.Name = "features"
EnumValueOptions_DebugRedact_field_name protoreflect.Name = "debug_redact"
EnumValueOptions_FeatureSupport_field_name protoreflect.Name = "feature_support"
EnumValueOptions_UninterpretedOption_field_name protoreflect.Name = "uninterpreted_option"
EnumValueOptions_Deprecated_field_fullname protoreflect.FullName = "google.protobuf.EnumValueOptions.deprecated"
EnumValueOptions_Features_field_fullname protoreflect.FullName = "google.protobuf.EnumValueOptions.features"
EnumValueOptions_DebugRedact_field_fullname protoreflect.FullName = "google.protobuf.EnumValueOptions.debug_redact"
EnumValueOptions_FeatureSupport_field_fullname protoreflect.FullName = "google.protobuf.EnumValueOptions.feature_support"
EnumValueOptions_UninterpretedOption_field_fullname protoreflect.FullName = "google.protobuf.EnumValueOptions.uninterpreted_option"
)
@@ -842,6 +875,7 @@ const (
EnumValueOptions_Deprecated_field_number protoreflect.FieldNumber = 1
EnumValueOptions_Features_field_number protoreflect.FieldNumber = 2
EnumValueOptions_DebugRedact_field_number protoreflect.FieldNumber = 3
EnumValueOptions_FeatureSupport_field_number protoreflect.FieldNumber = 4
EnumValueOptions_UninterpretedOption_field_number protoreflect.FieldNumber = 999
)
@@ -1110,17 +1144,20 @@ const (
// Field names for google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault.
const (
FeatureSetDefaults_FeatureSetEditionDefault_Edition_field_name protoreflect.Name = "edition"
FeatureSetDefaults_FeatureSetEditionDefault_Features_field_name protoreflect.Name = "features"
FeatureSetDefaults_FeatureSetEditionDefault_Edition_field_name protoreflect.Name = "edition"
FeatureSetDefaults_FeatureSetEditionDefault_OverridableFeatures_field_name protoreflect.Name = "overridable_features"
FeatureSetDefaults_FeatureSetEditionDefault_FixedFeatures_field_name protoreflect.Name = "fixed_features"
FeatureSetDefaults_FeatureSetEditionDefault_Edition_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault.edition"
FeatureSetDefaults_FeatureSetEditionDefault_Features_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault.features"
FeatureSetDefaults_FeatureSetEditionDefault_Edition_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault.edition"
FeatureSetDefaults_FeatureSetEditionDefault_OverridableFeatures_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault.overridable_features"
FeatureSetDefaults_FeatureSetEditionDefault_FixedFeatures_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault.fixed_features"
)
// Field numbers for google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault.
const (
FeatureSetDefaults_FeatureSetEditionDefault_Edition_field_number protoreflect.FieldNumber = 3
FeatureSetDefaults_FeatureSetEditionDefault_Features_field_number protoreflect.FieldNumber = 2
FeatureSetDefaults_FeatureSetEditionDefault_Edition_field_number protoreflect.FieldNumber = 3
FeatureSetDefaults_FeatureSetEditionDefault_OverridableFeatures_field_number protoreflect.FieldNumber = 4
FeatureSetDefaults_FeatureSetEditionDefault_FixedFeatures_field_number protoreflect.FieldNumber = 5
)
// Names for google.protobuf.SourceCodeInfo.

View File

@@ -10,7 +10,7 @@ import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
)
const File_reflect_protodesc_proto_go_features_proto = "reflect/protodesc/proto/go_features.proto"
const File_google_protobuf_go_features_proto = "google/protobuf/go_features.proto"
// Names for google.protobuf.GoFeatures.
const (

View File

@@ -22,13 +22,13 @@ type Export struct{}
// NewError formats a string according to the format specifier and arguments and
// returns an error that has a "proto" prefix.
func (Export) NewError(f string, x ...interface{}) error {
func (Export) NewError(f string, x ...any) error {
return errors.New(f, x...)
}
// enum is any enum type generated by protoc-gen-go
// and must be a named int32 type.
type enum = interface{}
type enum = any
// EnumOf returns the protoreflect.Enum interface over e.
// It returns nil if e is nil.
@@ -81,7 +81,7 @@ func (Export) EnumStringOf(ed protoreflect.EnumDescriptor, n protoreflect.EnumNu
// message is any message type generated by protoc-gen-go
// and must be a pointer to a named struct type.
type message = interface{}
type message = any
// legacyMessageWrapper wraps a v2 message as a v1 message.
type legacyMessageWrapper struct{ m protoreflect.ProtoMessage }

View File

@@ -68,7 +68,7 @@ func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error {
}
for _, x := range *ext {
ei := getExtensionFieldInfo(x.Type())
if ei.funcs.isInit == nil {
if ei.funcs.isInit == nil || x.isUnexpandedLazy() {
continue
}
v := x.Value()

View File

@@ -99,6 +99,28 @@ func (f *ExtensionField) canLazy(xt protoreflect.ExtensionType) bool {
return false
}
// isUnexpandedLazy returns true if the ExensionField is lazy and not
// yet expanded, which means it's present and already checked for
// initialized required fields.
func (f *ExtensionField) isUnexpandedLazy() bool {
return f.lazy != nil && atomic.LoadUint32(&f.lazy.atomicOnce) == 0
}
// lazyBuffer retrieves the buffer for a lazy extension if it's not yet expanded.
//
// The returned buffer has to be kept over whatever operation we're planning,
// as re-retrieving it will fail after the message is lazily decoded.
func (f *ExtensionField) lazyBuffer() []byte {
// This function might be in the critical path, so check the atomic without
// taking a look first, then only take the lock if needed.
if !f.isUnexpandedLazy() {
return nil
}
f.lazy.mu.Lock()
defer f.lazy.mu.Unlock()
return f.lazy.b
}
func (f *ExtensionField) lazyInit() {
f.lazy.mu.Lock()
defer f.lazy.mu.Unlock()

View File

@@ -233,9 +233,15 @@ func sizeMessageInfo(p pointer, f *coderFieldInfo, opts marshalOptions) int {
}
func appendMessageInfo(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
calculatedSize := f.mi.sizePointer(p.Elem(), opts)
b = protowire.AppendVarint(b, f.wiretag)
b = protowire.AppendVarint(b, uint64(f.mi.sizePointer(p.Elem(), opts)))
return f.mi.marshalAppendPointer(b, p.Elem(), opts)
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
b, err := f.mi.marshalAppendPointer(b, p.Elem(), opts)
if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
return b, err
}
func consumeMessageInfo(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
@@ -262,14 +268,21 @@ func isInitMessageInfo(p pointer, f *coderFieldInfo) error {
return f.mi.checkInitializedPointer(p.Elem())
}
func sizeMessage(m proto.Message, tagsize int, _ marshalOptions) int {
return protowire.SizeBytes(proto.Size(m)) + tagsize
func sizeMessage(m proto.Message, tagsize int, opts marshalOptions) int {
return protowire.SizeBytes(opts.Options().Size(m)) + tagsize
}
func appendMessage(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
mopts := opts.Options()
calculatedSize := mopts.Size(m)
b = protowire.AppendVarint(b, wiretag)
b = protowire.AppendVarint(b, uint64(proto.Size(m)))
return opts.Options().MarshalAppend(b, m)
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
b, err := mopts.MarshalAppend(b, m)
if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
return b, err
}
func consumeMessage(b []byte, m proto.Message, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
@@ -405,8 +418,8 @@ func consumeGroupType(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInf
return f.mi.unmarshalPointer(b, p.Elem(), f.num, opts)
}
func sizeGroup(m proto.Message, tagsize int, _ marshalOptions) int {
return 2*tagsize + proto.Size(m)
func sizeGroup(m proto.Message, tagsize int, opts marshalOptions) int {
return 2*tagsize + opts.Options().Size(m)
}
func appendGroup(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
@@ -482,10 +495,14 @@ func appendMessageSliceInfo(b []byte, p pointer, f *coderFieldInfo, opts marshal
b = protowire.AppendVarint(b, f.wiretag)
siz := f.mi.sizePointer(v, opts)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
b, err = f.mi.marshalAppendPointer(b, v, opts)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}
@@ -520,28 +537,34 @@ func isInitMessageSliceInfo(p pointer, f *coderFieldInfo) error {
return nil
}
func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, _ marshalOptions) int {
func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, opts marshalOptions) int {
mopts := opts.Options()
s := p.PointerSlice()
n := 0
for _, v := range s {
m := asMessage(v.AsValueOf(goType.Elem()))
n += protowire.SizeBytes(proto.Size(m)) + tagsize
n += protowire.SizeBytes(mopts.Size(m)) + tagsize
}
return n
}
func appendMessageSlice(b []byte, p pointer, wiretag uint64, goType reflect.Type, opts marshalOptions) ([]byte, error) {
mopts := opts.Options()
s := p.PointerSlice()
var err error
for _, v := range s {
m := asMessage(v.AsValueOf(goType.Elem()))
b = protowire.AppendVarint(b, wiretag)
siz := proto.Size(m)
siz := mopts.Size(m)
b = protowire.AppendVarint(b, uint64(siz))
b, err = opts.Options().MarshalAppend(b, m)
before := len(b)
b, err = mopts.MarshalAppend(b, m)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}
@@ -582,11 +605,12 @@ func isInitMessageSlice(p pointer, goType reflect.Type) error {
// Slices of messages
func sizeMessageSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int {
mopts := opts.Options()
list := listv.List()
n := 0
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
n += protowire.SizeBytes(proto.Size(m)) + tagsize
n += protowire.SizeBytes(mopts.Size(m)) + tagsize
}
return n
}
@@ -597,13 +621,17 @@ func appendMessageSliceValue(b []byte, listv protoreflect.Value, wiretag uint64,
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
b = protowire.AppendVarint(b, wiretag)
siz := proto.Size(m)
siz := mopts.Size(m)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
var err error
b, err = mopts.MarshalAppend(b, m)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}
@@ -651,11 +679,12 @@ var coderMessageSliceValue = valueCoderFuncs{
}
func sizeGroupSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int {
mopts := opts.Options()
list := listv.List()
n := 0
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
n += 2*tagsize + proto.Size(m)
n += 2*tagsize + mopts.Size(m)
}
return n
}
@@ -738,12 +767,13 @@ func makeGroupSliceFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type)
}
}
func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, _ marshalOptions) int {
func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, opts marshalOptions) int {
mopts := opts.Options()
s := p.PointerSlice()
n := 0
for _, v := range s {
m := asMessage(v.AsValueOf(messageType.Elem()))
n += 2*tagsize + proto.Size(m)
n += 2*tagsize + mopts.Size(m)
}
return n
}

View File

@@ -9,6 +9,7 @@ import (
"sort"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/reflect/protoreflect"
)
@@ -240,11 +241,16 @@ func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coder
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
size += mapi.valFuncs.size(val, mapValTagSize, opts)
b = protowire.AppendVarint(b, uint64(size))
before := len(b)
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
if err != nil {
return nil, err
}
return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
b, err = mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
if measuredSize := len(b) - before; size != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(size, measuredSize)
}
return b, err
} else {
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
val := pointerOfValue(valrv)
@@ -259,7 +265,12 @@ func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coder
}
b = protowire.AppendVarint(b, mapi.valWiretag)
b = protowire.AppendVarint(b, uint64(valSize))
return f.mi.marshalAppendPointer(b, val, opts)
before := len(b)
b, err = f.mi.marshalAppendPointer(b, val, opts)
if measuredSize := len(b) - before; valSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(valSize, measuredSize)
}
return b, err
}
}

View File

@@ -26,6 +26,15 @@ func sizeMessageSet(mi *MessageInfo, p pointer, opts marshalOptions) (size int)
}
num, _ := protowire.DecodeTag(xi.wiretag)
size += messageset.SizeField(num)
if fullyLazyExtensions(opts) {
// Don't expand the extension, instead use the buffer to calculate size
if lb := x.lazyBuffer(); lb != nil {
// We got hold of the buffer, so it's still lazy.
// Don't count the tag size in the extension buffer, it's already added.
size += protowire.SizeTag(messageset.FieldMessage) + len(lb) - xi.tagsize
continue
}
}
size += xi.funcs.size(x.Value(), protowire.SizeTag(messageset.FieldMessage), opts)
}
@@ -85,6 +94,19 @@ func marshalMessageSetField(mi *MessageInfo, b []byte, x ExtensionField, opts ma
xi := getExtensionFieldInfo(x.Type())
num, _ := protowire.DecodeTag(xi.wiretag)
b = messageset.AppendFieldStart(b, num)
if fullyLazyExtensions(opts) {
// Don't expand the extension if it's still in wire format, instead use the buffer content.
if lb := x.lazyBuffer(); lb != nil {
// The tag inside the lazy buffer is a different tag (the extension
// number), but what we need here is the tag for FieldMessage:
b = protowire.AppendVarint(b, protowire.EncodeTag(messageset.FieldMessage, protowire.BytesType))
b = append(b, lb[xi.tagsize:]...)
b = messageset.AppendFieldEnd(b)
return b, nil
}
}
b, err := xi.funcs.marshal(b, x.Value(), protowire.EncodeTag(messageset.FieldMessage, protowire.BytesType), opts)
if err != nil {
return b, err

View File

@@ -14,7 +14,7 @@ import (
// unwrapper unwraps the value to the underlying value.
// This is implemented by List and Map.
type unwrapper interface {
protoUnwrap() interface{}
protoUnwrap() any
}
// A Converter coverts to/from Go reflect.Value types and protobuf protoreflect.Value types.

View File

@@ -136,6 +136,6 @@ func (ls *listReflect) NewElement() protoreflect.Value {
func (ls *listReflect) IsValid() bool {
return !ls.v.IsNil()
}
func (ls *listReflect) protoUnwrap() interface{} {
func (ls *listReflect) protoUnwrap() any {
return ls.v.Interface()
}

Some files were not shown because too many files have changed in this diff Show More