diff --git a/clientconn.go b/clientconn.go index 6a44343aea00..346fc0e14512 100644 --- a/clientconn.go +++ b/clientconn.go @@ -117,12 +117,11 @@ func (dcs *defaultConfigSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*ires }, nil } -// newClient returns a new client in idle mode. -func newClient(target string, opts ...DialOption) (conn *ClientConn, err error) { +func newClient(target, defaultScheme string, opts ...DialOption) (conn *ClientConn, err error) { cc := &ClientConn{ target: target, conns: make(map[*addrConn]struct{}), - dopts: defaultDialOptions(), + dopts: defaultDialOptions(defaultScheme), czData: new(channelzData), } @@ -191,6 +190,11 @@ func newClient(target string, opts ...DialOption) (conn *ClientConn, err error) return cc, nil } +// NewClient returns a new client in idle mode. +func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error) { + return newClient(target, "dns", opts...) +} + // DialContext creates a client connection to the given target. By default, it's // a non-blocking dial (the function won't wait for connections to be // established, and connecting happens in the background). To make it a blocking @@ -208,7 +212,8 @@ func newClient(target string, opts ...DialOption) (conn *ClientConn, err error) // https://github.com/grpc/grpc/blob/master/doc/naming.md. // e.g. to use dns resolver, a "dns:///" prefix should be applied to the target. func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) { - cc, err := newClient(target, opts...) + // At the end of this method, we kick the channel out of idle, rather than waiting for the first rpc. + cc, err := newClient(target, "passthrough", opts...) if err != nil { return nil, err } @@ -1740,8 +1745,13 @@ func (cc *ClientConn) parseTargetAndFindResolver() error { // We are here because the user's dial target did not contain a scheme or // specified an unregistered scheme. We should fallback to the default // scheme, except when a custom dialer is specified in which case, we should - // always use passthrough scheme. - defScheme := resolver.GetDefaultScheme() + // always use passthrough scheme. For either case, we need to respect any overridden + // global defaults set by the user. + defScheme := cc.dopts.defScheme + if internal.UserSetDefaultScheme { + defScheme = resolver.GetDefaultScheme() + } + channelz.Infof(logger, cc.channelzID, "fallback to scheme %q", defScheme) canonicalTarget := defScheme + ":///" + cc.target diff --git a/clientconn_parsed_target_test.go b/clientconn_parsed_target_test.go index 1ff46aaf08c7..abb80611eae4 100644 --- a/clientconn_parsed_target_test.go +++ b/clientconn_parsed_target_test.go @@ -28,34 +28,87 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/testutils" - "google.golang.org/grpc/resolver" ) +func generateTarget(scheme string, target string) resolver.Target { + return resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("%s:///%s", scheme, target))} +} + +// This is here just in case another test calls the SetDefaultScheme method. +func resetInitialResolverState() { + resolver.SetDefaultScheme("passthrough") + internal.UserSetDefaultScheme = false +} + func (s) TestParsedTarget_Success_WithoutCustomDialer(t *testing.T) { - defScheme := resolver.GetDefaultScheme() + resetInitialResolverState() + dialScheme := resolver.GetDefaultScheme() + newClientScheme := "dns" tests := []struct { - target string - wantParsed resolver.Target + target string + wantDialParse resolver.Target + wantNewClientParse resolver.Target }{ // No scheme is specified. - {target: "://a/b", wantParsed: resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("%s:///%s", defScheme, "://a/b"))}}, - {target: "a//b", wantParsed: resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("%s:///%s", defScheme, "a//b"))}}, + { + target: "://a/b", + wantDialParse: generateTarget(dialScheme, "://a/b"), + wantNewClientParse: generateTarget(newClientScheme, "://a/b"), + }, + { + target: "a//b", + wantDialParse: generateTarget(dialScheme, "a//b"), + wantNewClientParse: generateTarget(newClientScheme, "a//b"), + }, // An unregistered scheme is specified. - {target: "a:///", wantParsed: resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("%s:///%s", defScheme, "a:///"))}}, - {target: "a:b", wantParsed: resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("%s:///%s", defScheme, "a:b"))}}, + { + target: "a:///", + wantDialParse: generateTarget(dialScheme, "a:///"), + wantNewClientParse: generateTarget(newClientScheme, "a:///"), + }, + { + target: "a:b", + wantDialParse: generateTarget(dialScheme, "a:b"), + wantNewClientParse: generateTarget(newClientScheme, "a:b"), + }, // A registered scheme is specified. - {target: "dns://a.server.com/google.com", wantParsed: resolver.Target{URL: *testutils.MustParseURL("dns://a.server.com/google.com")}}, - {target: "unix-abstract:/ a///://::!@#$%25^&*()b", wantParsed: resolver.Target{URL: *testutils.MustParseURL("unix-abstract:/ a///://::!@#$%25^&*()b")}}, - {target: "unix-abstract:passthrough:abc", wantParsed: resolver.Target{URL: *testutils.MustParseURL("unix-abstract:passthrough:abc")}}, - {target: "passthrough:///unix:///a/b/c", wantParsed: resolver.Target{URL: *testutils.MustParseURL("passthrough:///unix:///a/b/c")}}, + { + target: "dns://a.server.com/google.com", + wantDialParse: resolver.Target{URL: *testutils.MustParseURL("dns://a.server.com/google.com")}, + wantNewClientParse: resolver.Target{URL: *testutils.MustParseURL("dns://a.server.com/google.com")}, + }, + { + target: "unix-abstract:/ a///://::!@#$%25^&*()b", + wantDialParse: resolver.Target{URL: *testutils.MustParseURL("unix-abstract:/ a///://::!@#$%25^&*()b")}, + wantNewClientParse: resolver.Target{URL: *testutils.MustParseURL("unix-abstract:/ a///://::!@#$%25^&*()b")}, + }, + { + target: "unix-abstract:passthrough:abc", + wantDialParse: resolver.Target{URL: *testutils.MustParseURL("unix-abstract:passthrough:abc")}, + wantNewClientParse: resolver.Target{URL: *testutils.MustParseURL("unix-abstract:passthrough:abc")}, + }, + { + target: "passthrough:///unix:///a/b/c", + wantDialParse: resolver.Target{URL: *testutils.MustParseURL("passthrough:///unix:///a/b/c")}, + wantNewClientParse: resolver.Target{URL: *testutils.MustParseURL("passthrough:///unix:///a/b/c")}, + }, // Cases for `scheme:absolute-path`. - {target: "dns:/a/b/c", wantParsed: resolver.Target{URL: *testutils.MustParseURL("dns:/a/b/c")}}, - {target: "unregistered:/a/b/c", wantParsed: resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("%s:///%s", defScheme, "unregistered:/a/b/c"))}}, + { + target: "dns:/a/b/c", + wantDialParse: resolver.Target{URL: *testutils.MustParseURL("dns:/a/b/c")}, + wantNewClientParse: resolver.Target{URL: *testutils.MustParseURL("dns:/a/b/c")}, + }, + { + target: "unregistered:/a/b/c", + wantDialParse: generateTarget(dialScheme, "unregistered:/a/b/c"), + wantNewClientParse: generateTarget(newClientScheme, "unregistered:/a/b/c"), + }, } for _, test := range tests { @@ -66,8 +119,18 @@ func (s) TestParsedTarget_Success_WithoutCustomDialer(t *testing.T) { } defer cc.Close() - if !cmp.Equal(cc.parsedTarget, test.wantParsed) { - t.Errorf("cc.parsedTarget for dial target %q = %+v, want %+v", test.target, cc.parsedTarget, test.wantParsed) + if !cmp.Equal(cc.parsedTarget, test.wantDialParse) { + t.Errorf("cc.parsedTarget for dial target %q = %+v, want %+v", test.target, cc.parsedTarget, test.wantDialParse) + } + + cc, err = NewClient(test.target, WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("NewClient(%q) failed: %v", test.target, err) + } + defer cc.Close() + + if !cmp.Equal(cc.parsedTarget, test.wantNewClientParse) { + t.Errorf("cc.parsedTarget for newClient target %q = %+v, want %+v", test.target, cc.parsedTarget, test.wantNewClientParse) } }) } @@ -93,6 +156,7 @@ func (s) TestParsedTarget_Failure_WithoutCustomDialer(t *testing.T) { } func (s) TestParsedTarget_WithCustomDialer(t *testing.T) { + resetInitialResolverState() defScheme := resolver.GetDefaultScheme() tests := []struct { target string diff --git a/credentials/google/google_test.go b/credentials/google/google_test.go index 1809d545d0ec..12c151ba5428 100644 --- a/credentials/google/google_test.go +++ b/credentials/google/google_test.go @@ -24,9 +24,9 @@ import ( "testing" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/internal" icredentials "google.golang.org/grpc/internal/credentials" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/resolver" ) @@ -109,7 +109,7 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) { { name: "with non-CFE cluster name", ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ - Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes, + Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes, }), // non-CFE backends should use alts. wantTyp: "alts", @@ -117,7 +117,7 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) { { name: "with CFE cluster name", ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ - Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "google_cfe_bigtable.googleapis.com").Attributes, + Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "google_cfe_bigtable.googleapis.com").Attributes, }), // CFE should use tls. wantTyp: "tls", @@ -125,7 +125,7 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) { { name: "with xdstp CFE cluster name", ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ - Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://traffic-director-c2p.xds.googleapis.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes, + Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://traffic-director-c2p.xds.googleapis.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes, }), // CFE should use tls. wantTyp: "tls", @@ -133,7 +133,7 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) { { name: "with xdstp non-CFE cluster name", ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ - Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://other.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes, + Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://other.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes, }), // non-CFE should use atls. wantTyp: "alts", diff --git a/credentials/google/xds.go b/credentials/google/xds.go index 2c5c8b9eee13..cccb22271ee5 100644 --- a/credentials/google/xds.go +++ b/credentials/google/xds.go @@ -25,7 +25,7 @@ import ( "strings" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/xds" ) const cfeClusterNamePrefix = "google_cfe_" @@ -63,7 +63,7 @@ func clusterName(ctx context.Context) string { if chi.Attributes == nil { return "" } - cluster, _ := internal.GetXDSHandshakeClusterName(chi.Attributes) + cluster, _ := xds.GetXDSHandshakeClusterName(chi.Attributes) return cluster } diff --git a/credentials/google/xds_test.go b/credentials/google/xds_test.go index 8aeba396a518..b62e7a73bc1d 100644 --- a/credentials/google/xds_test.go +++ b/credentials/google/xds_test.go @@ -23,15 +23,15 @@ import ( "testing" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/internal" icredentials "google.golang.org/grpc/internal/credentials" + "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/resolver" ) func (s) TestIsDirectPathCluster(t *testing.T) { c := func(cluster string) context.Context { return icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ - Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, cluster).Attributes, + Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, cluster).Attributes, }) } diff --git a/dialoptions.go b/dialoptions.go index a95f86a1f4bf..19e3de6d4398 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -79,6 +79,7 @@ type dialOptions struct { resolvers []resolver.Builder idleTimeout time.Duration recvBufferPool SharedBufferPool + defScheme string } // DialOption configures how we set up the connection. @@ -631,7 +632,7 @@ func withHealthCheckFunc(f internal.HealthChecker) DialOption { }) } -func defaultDialOptions() dialOptions { +func defaultDialOptions(defScheme string) dialOptions { return dialOptions{ copts: transport.ConnectOptions{ ReadBufferSize: defaultReadBufSize, @@ -643,6 +644,7 @@ func defaultDialOptions() dialOptions { healthCheckFunc: internal.HealthCheckFunc, idleTimeout: 30 * time.Minute, recvBufferPool: nopBufferPool{}, + defScheme: defScheme, } } diff --git a/internal/internal.go b/internal/internal.go index 5082fdaba961..48d24bdb4e69 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -197,6 +197,9 @@ var ( // FromOutgoingContextRaw returns the un-merged, intermediary contents of metadata.rawMD. FromOutgoingContextRaw any // func(context.Context) (metadata.MD, [][]string, bool) + + // UserSetDefaultScheme is set to true if the user has overridden the default resolver scheme. + UserSetDefaultScheme bool = false ) // HealthChecker defines the signature of the client-side LB channel health checking function. diff --git a/internal/xds_handshake_cluster.go b/internal/xds/xds.go similarity index 89% rename from internal/xds_handshake_cluster.go rename to internal/xds/xds.go index e8b492774d1a..024c388b7aa7 100644 --- a/internal/xds_handshake_cluster.go +++ b/internal/xds/xds.go @@ -14,7 +14,9 @@ * limitations under the License. */ -package internal +// Package xds contains methods to Get/Set handshake cluster names. It is separated +// out from the top level /internal package to avoid circular dependencies. +package xds import ( "google.golang.org/grpc/attributes" diff --git a/resolver/resolver.go b/resolver/resolver.go index 95c46452464e..202854511b81 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc/attributes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" "google.golang.org/grpc/serviceconfig" ) @@ -63,16 +64,18 @@ func Get(scheme string) Builder { } // SetDefaultScheme sets the default scheme that will be used. The default -// default scheme is "passthrough". +// scheme is initially set to "passthrough". // // NOTE: this function must only be called during initialization time (i.e. in // an init() function), and is not thread-safe. The scheme set last overrides // previously set values. func SetDefaultScheme(scheme string) { defaultScheme = scheme + internal.UserSetDefaultScheme = true } -// GetDefaultScheme gets the default scheme that will be used. +// GetDefaultScheme gets the default scheme that will be used by grpc.Dial. If +// SetDefaultScheme is never called, the default scheme used by grpc.NewClient is "dns" instead. func GetDefaultScheme() string { return defaultScheme } diff --git a/xds/internal/balancer/clusterimpl/balancer_test.go b/xds/internal/balancer/clusterimpl/balancer_test.go index 1edf3b8b857a..d7221c32a81a 100644 --- a/xds/internal/balancer/clusterimpl/balancer_test.go +++ b/xds/internal/balancer/clusterimpl/balancer_test.go @@ -32,11 +32,11 @@ import ( "google.golang.org/grpc/balancer/base" "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/internal/grpctest" internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/resolver" xdsinternal "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/testutils/fakeclient" @@ -464,7 +464,7 @@ func (s) TestClusterNameInAddressAttributes(t *testing.T) { if got, want := addrs1[0].Addr, testBackendAddrs[0].Addr; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } - cn, ok := internal.GetXDSHandshakeClusterName(addrs1[0].Attributes) + cn, ok := xds.GetXDSHandshakeClusterName(addrs1[0].Attributes) if !ok || cn != testClusterName { t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn, ok, testClusterName) } @@ -495,7 +495,7 @@ func (s) TestClusterNameInAddressAttributes(t *testing.T) { t.Fatalf("sc is created with addr %v, want %v", got, want) } // New addresses should have the new cluster name. - cn2, ok := internal.GetXDSHandshakeClusterName(addrs2[0].Attributes) + cn2, ok := xds.GetXDSHandshakeClusterName(addrs2[0].Attributes) if !ok || cn2 != testClusterName2 { t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn2, ok, testClusterName2) } diff --git a/xds/internal/balancer/clusterimpl/clusterimpl.go b/xds/internal/balancer/clusterimpl/clusterimpl.go index 407d2deff7d6..bee5d2c97a33 100644 --- a/xds/internal/balancer/clusterimpl/clusterimpl.go +++ b/xds/internal/balancer/clusterimpl/clusterimpl.go @@ -31,12 +31,12 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/balancer/gracefulswitch" "google.golang.org/grpc/internal/buffer" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" xdsinternal "google.golang.org/grpc/xds/internal" @@ -359,7 +359,7 @@ func (b *clusterImplBalancer) NewSubConn(addrs []resolver.Address, opts balancer newAddrs := make([]resolver.Address, len(addrs)) var lID xdsinternal.LocalityID for i, addr := range addrs { - newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName) + newAddrs[i] = xds.SetXDSHandshakeClusterName(addr, clusterName) lID = xdsinternal.GetLocalityID(newAddrs[i]) } var sc balancer.SubConn @@ -384,7 +384,7 @@ func (b *clusterImplBalancer) UpdateAddresses(sc balancer.SubConn, addrs []resol newAddrs := make([]resolver.Address, len(addrs)) var lID xdsinternal.LocalityID for i, addr := range addrs { - newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName) + newAddrs[i] = xds.SetXDSHandshakeClusterName(addr, clusterName) lID = xdsinternal.GetLocalityID(newAddrs[i]) } if scw, ok := sc.(*scWrapper); ok {