Skip to content
Snippets Groups Projects
Commit 28f93e63 authored by John Cai's avatar John Cai
Browse files

Merge branch 'jv-sidechannel-client' into 'master'

client: add sidechannel support

Closes gitlab-com/gl-infra/scalability#1303

See merge request !3900
parents c2cf30c9 e64a6c21
No related branches found
No related tags found
1 merge request!3900client: add sidechannel support
Pipeline #382961359 failed
......@@ -4,6 +4,7 @@ import (
"context"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/client"
"gitlab.com/gitlab-org/gitaly/v14/internal/sidechannel"
"google.golang.org/grpc"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
)
......@@ -30,6 +31,15 @@ func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, erro
return DialContext(context.Background(), rawAddress, connOpts)
}
// DialSidechannel configures the dialer to establish a Gitaly
// backchannel connection instead of a regular gRPC connection. It also
// injects sr as a sidechannel registry, so that Gitaly can establish
// sidechannels back to the client.
func DialSidechannel(ctx context.Context, rawAddress string, sr *SidechannelRegistry, connOpts []grpc.DialOption) (*grpc.ClientConn, error) {
clientHandshaker := sidechannel.NewClientHandshaker(sr.logger, sr.registry)
return client.Dial(ctx, rawAddress, connOpts, clientHandshaker)
}
// FailOnNonTempDialError helps to identify if remote listener is ready to accept new connections.
func FailOnNonTempDialError() []grpc.DialOption {
return []grpc.DialOption{
......
......@@ -12,6 +12,7 @@ import (
"testing"
"github.com/opentracing/opentracing-go"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/uber/jaeger-client-go"
......@@ -26,6 +27,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/status"
)
......@@ -37,7 +39,11 @@ func TestDial(t *testing.T) {
t.Log("WARNING. Proxy configuration detected from environment settings. This test failure may be related to proxy configuration. Please process with caution")
}
stop, connectionMap := startListeners(t)
stop, connectionMap := startListeners(t, func(creds credentials.TransportCredentials) *grpc.Server {
srv := grpc.NewServer(grpc.Creds(creds))
healthpb.RegisterHealthServer(srv, &healthServer{})
return srv
})
defer stop()
unixSocketAbsPath := connectionMap["unix"]
......@@ -147,6 +153,110 @@ func TestDial(t *testing.T) {
}
}
func TestDialSidechannel(t *testing.T) {
if emitProxyWarning() {
t.Log("WARNING. Proxy configuration detected from environment settings. This test failure may be related to proxy configuration. Please process with caution")
}
stop, connectionMap := startListeners(t, func(creds credentials.TransportCredentials) *grpc.Server {
return grpc.NewServer(TestSidechannelServer(newLogger(t), creds, func(
_ interface{},
stream grpc.ServerStream,
sidechannelConn io.ReadWriteCloser,
) error {
if method, ok := grpc.Method(stream.Context()); !ok || method != "/grpc.health.v1.Health/Check" {
return fmt.Errorf("unexpected method: %s", method)
}
var req healthpb.HealthCheckRequest
if err := stream.RecvMsg(&req); err != nil {
return fmt.Errorf("recv msg: %w", err)
}
if _, err := io.Copy(sidechannelConn, sidechannelConn); err != nil {
return fmt.Errorf("copy: %w", err)
}
if err := stream.SendMsg(&healthpb.HealthCheckResponse{}); err != nil {
return fmt.Errorf("send msg: %w", err)
}
return nil
})...)
})
defer stop()
unixSocketAbsPath := connectionMap["unix"]
tempDir := testhelper.TempDir(t)
unixSocketPath := filepath.Join(tempDir, "gitaly.socket")
require.NoError(t, os.Symlink(unixSocketAbsPath, unixSocketPath))
registry := NewSidechannelRegistry(newLogger(t))
tests := []struct {
name string
rawAddress string
envSSLCertFile string
dialOpts []grpc.DialOption
}{
{
name: "tcp sidechannel",
rawAddress: "tcp://localhost:" + connectionMap["tcp"], // "tcp://localhost:1234"
},
{
name: "tls sidechannel",
rawAddress: "tls://localhost:" + connectionMap["tls"], // "tls://localhost:1234"
envSSLCertFile: "./testdata/gitalycert.pem",
},
{
name: "unix sidechannel",
rawAddress: "unix:" + unixSocketAbsPath, // "unix:/tmp/temp-socket"
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envSSLCertFile != "" {
defer testhelper.ModifyEnvironment(t, gitalyx509.SSLCertFile, tt.envSSLCertFile)()
}
ctx, cancel := testhelper.Context()
defer cancel()
conn, err := DialSidechannel(ctx, tt.rawAddress, registry, tt.dialOpts)
require.NoError(t, err)
defer conn.Close()
ctx, scw := registry.Register(ctx, func(conn SidechannelConn) error {
const message = "hello world"
if _, err := io.WriteString(conn, message); err != nil {
return err
}
if err := conn.CloseWrite(); err != nil {
return err
}
buf, err := io.ReadAll(conn)
if err != nil {
return err
}
if string(buf) != message {
return fmt.Errorf("expected %q, got %q", message, buf)
}
return nil
})
defer scw.Close()
req := &healthpb.HealthCheckRequest{Service: "test sidechannel"}
_, err = healthpb.NewHealthClient(conn).Check(ctx, req)
require.NoError(t, err)
require.NoError(t, scw.Close())
})
}
}
type testSvc struct {
proxytestdata.UnimplementedTestServiceServer
PingMethod func(context.Context, *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error)
......@@ -414,26 +524,23 @@ func TestDial_Tracing(t *testing.T) {
}
// healthServer provide a basic GRPC health service endpoint for testing purposes
type healthServer struct{}
func (*healthServer) Check(context.Context, *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
return &healthpb.HealthCheckResponse{Status: healthpb.HealthCheckResponse_SERVING}, nil
type healthServer struct {
healthpb.UnimplementedHealthServer
}
func (*healthServer) Watch(*healthpb.HealthCheckRequest, healthpb.Health_WatchServer) error {
return status.Errorf(codes.Unimplemented, "Not implemented")
func (*healthServer) Check(context.Context, *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
return &healthpb.HealthCheckResponse{}, nil
}
// startTCPListener will start a insecure TCP listener on a random unused port
func startTCPListener(t testing.TB) (func(), string) {
func startTCPListener(t testing.TB, factory func(credentials.TransportCredentials) *grpc.Server) (func(), string) {
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
tcpPort := listener.Addr().(*net.TCPAddr).Port
address := fmt.Sprintf("%d", tcpPort)
grpcServer := grpc.NewServer()
healthpb.RegisterHealthServer(grpcServer, &healthServer{})
grpcServer := factory(insecure.NewCredentials())
go grpcServer.Serve(listener)
return func() {
......@@ -442,14 +549,13 @@ func startTCPListener(t testing.TB) (func(), string) {
}
// startUnixListener will start a unix socket listener using a temporary file
func startUnixListener(t testing.TB) (func(), string) {
func startUnixListener(t testing.TB, factory func(credentials.TransportCredentials) *grpc.Server) (func(), string) {
serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(t)
listener, err := net.Listen("unix", serverSocketPath)
require.NoError(t, err)
grpcServer := grpc.NewServer()
healthpb.RegisterHealthServer(grpcServer, &healthServer{})
grpcServer := factory(insecure.NewCredentials())
go grpcServer.Serve(listener)
return func() {
......@@ -459,7 +565,7 @@ func startUnixListener(t testing.TB) (func(), string) {
// startTLSListener will start a secure TLS listener on a random unused port
//go:generate openssl req -newkey rsa:4096 -new -nodes -x509 -days 3650 -out testdata/gitalycert.pem -keyout testdata/gitalykey.pem -subj "/C=US/ST=California/L=San Francisco/O=GitLab/OU=GitLab-Shell/CN=localhost" -addext "subjectAltName = IP:127.0.0.1, DNS:localhost"
func startTLSListener(t testing.TB) (func(), string) {
func startTLSListener(t testing.TB, factory func(credentials.TransportCredentials) *grpc.Server) (func(), string) {
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
......@@ -469,11 +575,12 @@ func startTLSListener(t testing.TB) (func(), string) {
cert, err := tls.LoadX509KeyPair("testdata/gitalycert.pem", "testdata/gitalykey.pem")
require.NoError(t, err)
grpcServer := grpc.NewServer(grpc.Creds(credentials.NewTLS(&tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
})))
healthpb.RegisterHealthServer(grpcServer, &healthServer{})
grpcServer := factory(
credentials.NewTLS(&tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}),
)
go grpcServer.Serve(listener)
return func() {
......@@ -481,18 +588,18 @@ func startTLSListener(t testing.TB) (func(), string) {
}, address
}
var listeners = map[string]func(testing.TB) (func(), string){
var listeners = map[string]func(testing.TB, func(credentials.TransportCredentials) *grpc.Server) (func(), string){
"tcp": startTCPListener,
"unix": startUnixListener,
"tls": startTLSListener,
}
// startListeners will start all the different listeners used in this test
func startListeners(t testing.TB) (func(), map[string]string) {
func startListeners(t testing.TB, factory func(credentials.TransportCredentials) *grpc.Server) (func(), map[string]string) {
var closers []func()
connectionMap := map[string]string{}
for k, v := range listeners {
closer, address := v(t)
closer, address := v(t, factory)
closers = append(closers, closer)
connectionMap[k] = address
}
......@@ -532,3 +639,5 @@ func TestHealthCheckDialer(t *testing.T) {
require.NoError(t, err)
require.NoError(t, cc.Close())
}
func newLogger(t testing.TB) *logrus.Entry { return logrus.NewEntry(testhelper.NewTestLogger(t)) }
package client
import (
"context"
"io"
"github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
"gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
"gitlab.com/gitlab-org/gitaly/v14/internal/sidechannel"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// SidechannelRegistry associates sidechannel callbacks with outbound
// gRPC calls.
type SidechannelRegistry struct {
registry *sidechannel.Registry
logger *logrus.Entry
}
// NewSidechannelRegistry returns a new registry.
func NewSidechannelRegistry(logger *logrus.Entry) *SidechannelRegistry {
return &SidechannelRegistry{
registry: sidechannel.NewRegistry(),
logger: logger,
}
}
// Register registers a callback. It adds metadata to ctx and returns the
// new context. The caller must use the new context for the gRPC call.
// Caller must Close() the returned SidechannelWaiter to prevent resource
// leaks.
func (sr *SidechannelRegistry) Register(
ctx context.Context,
callback func(SidechannelConn) error,
) (context.Context, *SidechannelWaiter) {
ctx, waiter := sidechannel.RegisterSidechannel(
ctx,
sr.registry,
func(cc *sidechannel.ClientConn) error { return callback(cc) },
)
return ctx, &SidechannelWaiter{waiter: waiter}
}
// SidechannelWaiter represents a pending sidechannel and its callback.
type SidechannelWaiter struct{ waiter *sidechannel.Waiter }
// Close de-registers the sidechannel callback. If the callback is still
// running, Close blocks until it is done and returns the error return
// value of the callback. If the callback has not been called yet, Close
// returns an error immediately.
func (w *SidechannelWaiter) Close() error { return w.waiter.Close() }
// SidechannelConn allows a client to read and write bytes with less
// overhead than doing so via gRPC messages.
type SidechannelConn interface {
io.ReadWriter
// CloseWrite tells the server we won't write any more data. We can still
// read data from the server after CloseWrite(). A typical use case is in
// an RPC where the byte stream has a request/response pattern: the
// client then uses CloseWrite() to signal the end of the request body.
// When the client calls CloseWrite(), the server receives EOF.
CloseWrite() error
}
// TestSidechannelServer allows downstream consumers of this package to
// create mock sidechannel gRPC servers.
func TestSidechannelServer(
logger *logrus.Entry,
creds credentials.TransportCredentials,
handler func(interface{}, grpc.ServerStream, io.ReadWriteCloser) error,
) []grpc.ServerOption {
lm := listenmux.New(creds)
lm.Register(backchannel.NewServerHandshaker(logger, backchannel.NewRegistry(), nil))
return []grpc.ServerOption{
grpc.Creds(lm),
grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
conn, err := sidechannel.OpenSidechannel(stream.Context())
if err != nil {
return err
}
defer conn.Close()
return handler(srv, stream, conn)
}),
}
}
......@@ -11,16 +11,13 @@ import (
"github.com/prometheus/client_golang/prometheus"
promtest "github.com/prometheus/client_golang/prometheus/testutil"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
gitalyauth "gitlab.com/gitlab-org/gitaly/v14/auth"
"gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
"gitlab.com/gitlab-org/gitaly/v14/internal/git"
"gitlab.com/gitlab-org/gitaly/v14/internal/git/gittest"
"gitlab.com/gitlab-org/gitaly/v14/internal/git/pktline"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config"
"gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
"gitlab.com/gitlab-org/gitaly/v14/internal/sidechannel"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testcfg"
......@@ -471,14 +468,9 @@ func makePostUploadPackRequest(ctx context.Context, t *testing.T, serverSocketPa
}
func dialSmartHTTPServerWithSidechannel(t *testing.T, serverSocketPath, token string, registry *sidechannel.Registry) *grpc.ClientConn {
logger := logrus.NewEntry(logrus.New())
t.Helper()
factory := func() backchannel.Server {
lm := listenmux.New(insecure.NewCredentials())
lm.Register(sidechannel.NewServerHandshaker(registry))
return grpc.NewServer(grpc.Creds(lm))
}
clientHandshaker := backchannel.NewClientHandshaker(logger, factory)
clientHandshaker := sidechannel.NewClientHandshaker(testhelper.DiscardTestEntry(t), registry)
connOpts := []grpc.DialOption{
grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials())),
grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(token)),
......
......@@ -9,8 +9,13 @@ import (
"strconv"
"time"
"github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/client"
"gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
)
......@@ -124,3 +129,16 @@ func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInf
func NewServerHandshaker(registry *Registry) *ServerHandshaker {
return &ServerHandshaker{registry: registry}
}
// NewClientHandshaker is used to enable sidechannel support on outbound
// gRPC connections.
func NewClientHandshaker(logger *logrus.Entry, registry *Registry) client.Handshaker {
return backchannel.NewClientHandshaker(
logger,
func() backchannel.Server {
lm := listenmux.New(insecure.NewCredentials())
lm.Register(NewServerHandshaker(registry))
return grpc.NewServer(grpc.Creds(lm))
},
)
}
......@@ -166,13 +166,7 @@ func startServer(t *testing.T, th testHandler, opts ...grpc.ServerOption) string
func dial(t *testing.T, addr string) (*grpc.ClientConn, *Registry) {
registry := NewRegistry()
factory := func() backchannel.Server {
lm := listenmux.New(insecure.NewCredentials())
lm.Register(NewServerHandshaker(registry))
return grpc.NewServer(grpc.Creds(lm))
}
clientHandshaker := backchannel.NewClientHandshaker(newLogger(), factory)
clientHandshaker := NewClientHandshaker(newLogger(), registry)
dialOpt := grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials()))
conn, err := grpc.Dial(addr, dialOpt)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment