diff --git a/pkg/ext-proc/health.go b/pkg/ext-proc/health.go
new file mode 100644
index 00000000..7a3cf8ef
--- /dev/null
+++ b/pkg/ext-proc/health.go
@@ -0,0 +1,59 @@
+package main
+
+import (
+	"context"
+	"fmt"
+
+	"google.golang.org/grpc/codes"
+	healthPb "google.golang.org/grpc/health/grpc_health_v1"
+	"google.golang.org/grpc/status"
+	"inference.networking.x-k8s.io/gateway-api-inference-extension/api/v1alpha1"
+	klog "k8s.io/klog/v2"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+// healthServer implements the gRPC health checking service. It embeds a
+// controller-runtime client that checkResources uses for its lookups.
+type healthServer struct {
+	client.Client
+}
+
+// Check reports SERVING when checkResources succeeds and NOT_SERVING when it
+// fails. The failure is logged rather than returned, so the RPC itself always
+// completes with a nil error.
+func (s *healthServer) Check(ctx context.Context, in *healthPb.HealthCheckRequest) (*healthPb.HealthCheckResponse, error) {
+	if err := s.checkResources(); err != nil {
+		klog.Infof("gRPC health check not serving: %s: %v", in.String(), err)
+		return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_NOT_SERVING}, nil
+	}
+	klog.Infof("gRPC health check serving: %s", in.String())
+	return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_SERVING}, nil
+}
+
+// Watch is not supported; it always returns a codes.Unimplemented status.
+func (s *healthServer) Watch(in *healthPb.HealthCheckRequest, srv healthPb.Health_WatchServer) error {
+	return status.Error(codes.Unimplemented, "Watch is not implemented")
+}
+
+// checkResources uses a client to list all InferenceModels in the configured namespace
+// and gets the configured InferencePool by name and namespace.
+func (s *healthServer) checkResources() error {
+	ctx := context.Background()
+	var infPool v1alpha1.InferencePool
+	if err := s.Client.Get(
+		ctx,
+		client.ObjectKey{Name: *poolName, Namespace: *poolNamespace},
+		&infPool,
+	); err != nil {
+		return fmt.Errorf("failed to get InferencePool %s/%s: %w", *poolNamespace, *poolName, err)
+	}
+	klog.Infof("Successfully retrieved InferencePool %s/%s", *poolNamespace, *poolName)
+
+	var modelList v1alpha1.InferenceModelList
+	if err := s.Client.List(ctx, &modelList, client.InNamespace(*poolNamespace)); err != nil {
+		return fmt.Errorf("failed to list InferenceModels in namespace %s: %w", *poolNamespace, err)
+	}
+	klog.Infof("Found %d InferenceModels in namespace %s", len(modelList.Items), *poolNamespace)
+
+	return nil
+}
diff --git a/pkg/ext-proc/main.go b/pkg/ext-proc/main.go
index 3ef10074..52dceb27 100644
--- a/pkg/ext-proc/main.go
+++ b/pkg/ext-proc/main.go
@@ -1,20 +1,14 @@
 package main
 
 import (
-	"context"
 	"flag"
 	"fmt"
 	"net"
-	"os"
-	"os/signal"
-	"syscall"
 	"time"
 
 	extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
 	"google.golang.org/grpc"
-	"google.golang.org/grpc/codes"
 	healthPb "google.golang.org/grpc/health/grpc_health_v1"
-	"google.golang.org/grpc/status"
 	"inference.networking.x-k8s.io/gateway-api-inference-extension/api/v1alpha1"
 	"inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend"
 	"inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend/vllm"
@@ -28,10 +22,14 @@ import (
 )
 
 var (
-	port = flag.Int(
-		"port",
+	grpcPort = flag.Int(
+		"grpcPort",
 		9002,
-		"gRPC port")
+		"The gRPC port used for communicating with Envoy proxy")
+	grpcHealthPort = flag.Int(
+		"grpcHealthPort",
+		9003,
+		"The port used for gRPC liveness and readiness probes")
 	targetPodHeader = flag.String(
 		"targetPodHeader",
 		"target-pod",
@@ -64,32 +62,22 @@ var (
 	scheme = runtime.NewScheme()
 )
 
-type healthServer struct{}
-
-func (s *healthServer) Check(
-	ctx context.Context,
-	in *healthPb.HealthCheckRequest,
-) (*healthPb.HealthCheckResponse, error) {
-	klog.Infof("Handling grpc Check request + %s", in.String())
-	return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_SERVING}, nil
-}
-
-func (s *healthServer) Watch(in *healthPb.HealthCheckRequest, srv healthPb.Health_WatchServer) error {
-	return status.Error(codes.Unimplemented, "Watch is not implemented")
-}
-
 func init() {
 	utilruntime.Must(clientgoscheme.AddToScheme(scheme))
 	utilruntime.Must(v1alpha1.AddToScheme(scheme))
 }
 
 func main() {
-
 	klog.InitFlags(nil)
 	flag.Parse()
 
 	ctrl.SetLogger(klog.TODO())
 
+	// Validate flags
+	if err := validateFlags(); err != nil {
+		klog.Fatalf("flag validation failed: %v", err)
+	}
+
 	// Print all flag values
 	flags := "Flags: "
 	flag.VisitAll(func(f *flag.Flag) {
@@ -97,22 +85,16 @@ func main() {
 	})
 	klog.Info(flags)
 
-	klog.Infof("Listening on %q", fmt.Sprintf(":%d", *port))
-	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
+	// Create a new manager to manage controllers
+	mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{Scheme: scheme})
 	if err != nil {
-		klog.Fatalf("failed to listen: %v", err)
+		klog.Fatalf("failed to start manager: %v", err)
 	}
 
+	// Create the data store used to cache watched resources
 	datastore := backend.NewK8sDataStore()
 
-	mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{
-		Scheme: scheme,
-	})
-	if err != nil {
-		klog.Error(err, "unable to start manager")
-		os.Exit(1)
-	}
-
+	// Create the controllers and register them with the manager
 	if err := (&backend.InferencePoolReconciler{
 		Datastore:     datastore,
 		Scheme:        mgr.GetScheme(),
@@ -121,7 +103,7 @@ func main() {
 		PoolNamespace: *poolNamespace,
 		Record:        mgr.GetEventRecorderFor("InferencePool"),
 	}).SetupWithManager(mgr); err != nil {
-		klog.Error(err, "Error setting up InferencePoolReconciler")
+		klog.Fatalf("Error setting up InferencePoolReconciler: %v", err)
 	}
 
 	if err := (&backend.InferenceModelReconciler{
@@ -132,7 +114,7 @@ func main() {
 		PoolNamespace: *poolNamespace,
 		Record:        mgr.GetEventRecorderFor("InferenceModel"),
 	}).SetupWithManager(mgr); err != nil {
-		klog.Error(err, "Error setting up InferenceModelReconciler")
+		klog.Fatalf("Error setting up InferenceModelReconciler: %v", err)
 	}
 
 	if err := (&backend.EndpointSliceReconciler{
@@ -143,53 +125,137 @@ func main() {
 		ServiceName: *serviceName,
 		Zone:        *zone,
 	}).SetupWithManager(mgr); err != nil {
-		klog.Error(err, "Error setting up EndpointSliceReconciler")
+		klog.Fatalf("Error setting up EndpointSliceReconciler: %v", err)
+	}
+
+	// Channel to handle error signals for goroutines
+	errChan := make(chan error, 1)
+
+	// Start each component in its own goroutine
+	startControllerManager(mgr, errChan)
+	healthSvr := startHealthServer(mgr, errChan, *grpcHealthPort)
+	extProcSvr := startExternalProcessorServer(
+		errChan,
+		datastore,
+		*grpcPort,
+		*refreshPodsInterval,
+		*refreshMetricsInterval,
+		*targetPodHeader,
+	)
+
+	// Wait for first error from any goroutine
+	err = <-errChan
+	if err != nil {
+		klog.Errorf("goroutine failed: %v", err)
+	} else {
+		klog.Infof("Manager exited gracefully")
 	}
 
-	errChan := make(chan error)
+	// Gracefully shutdown components
+	if healthSvr != nil {
+		klog.Info("Health server shutting down...")
+		healthSvr.GracefulStop()
+	}
+	if extProcSvr != nil {
+		klog.Info("Ext-proc server shutting down...")
+		extProcSvr.GracefulStop()
+	}
+
+	klog.Info("All components stopped gracefully")
+}
+
+// startControllerManager runs the controller manager in a goroutine.
+// It sends a wrapped error on errChan if the manager fails, or nil once the
+// manager exits without error.
+func startControllerManager(mgr ctrl.Manager, errChan chan<- error) {
 	go func() {
+		// Blocking and will return when shutdown is complete.
 		if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
-			klog.Error(err, "Error running manager")
-			errChan <- err
+			errChan <- fmt.Errorf("controller manager failed to start: %w", err)
+			// A failure is not a graceful exit; skip the nil send below.
+			return
 		}
+		// Manager exited gracefully
+		klog.Info("Controller manager shutting down...")
+		errChan <- nil
 	}()
+}
 
-	s := grpc.NewServer()
+// startHealthServer starts the gRPC health probe server in a goroutine.
+// It returns the server so the caller can stop it; listen and serve failures
+// are reported on errChan.
+func startHealthServer(mgr ctrl.Manager, errChan chan<- error, port int) *grpc.Server {
+	healthSvr := grpc.NewServer()
+	healthPb.RegisterHealthServer(healthSvr, &healthServer{Client: mgr.GetClient()})
 
-	pp := backend.NewProvider(&vllm.PodMetricsClientImpl{}, datastore)
-	if err := pp.Init(*refreshPodsInterval, *refreshMetricsInterval); err != nil {
-		klog.Fatalf("failed to initialize: %v", err)
-	}
-	extProcPb.RegisterExternalProcessorServer(
-		s,
-		handlers.NewServer(
-			pp,
-			scheduling.NewScheduler(pp),
-			*targetPodHeader,
-			datastore))
-	healthPb.RegisterHealthServer(s, &healthServer{})
-
-	klog.Infof("Starting gRPC server on port :%v", *port)
-
-	// shutdown
-	var gracefulStop = make(chan os.Signal, 1)
-	signal.Notify(gracefulStop, syscall.SIGTERM)
-	signal.Notify(gracefulStop, syscall.SIGINT)
 	go func() {
-		select {
-		case sig := <-gracefulStop:
-			klog.Infof("caught sig: %+v", sig)
-			os.Exit(0)
-		case err := <-errChan:
-			klog.Infof("caught error in controller: %+v", err)
-			os.Exit(0)
+		healthLis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
+		if err != nil {
+			errChan <- fmt.Errorf("health server failed to listen: %w", err)
+			// Serve must not be called with a nil listener.
+			return
 		}
+		klog.Infof("Health server listening on port: %d", port)
 
+		// Blocking and will return when shutdown is complete.
+		if serveErr := healthSvr.Serve(healthLis); serveErr != nil && serveErr != grpc.ErrServerStopped {
+			errChan <- fmt.Errorf("health server failed: %w", serveErr)
+		}
 	}()
+	return healthSvr
+}
 
-	err = s.Serve(lis)
-	if err != nil {
-		klog.Fatalf("Ext-proc failed with the err: %v", err)
+// startExternalProcessorServer starts the Envoy external processor server in a goroutine.
+// It returns the server so the caller can stop it; listen, provider
+// initialization and serve failures are reported on errChan.
+func startExternalProcessorServer(
+	errChan chan<- error,
+	datastore *backend.K8sDatastore,
+	port int,
+	refreshPodsInterval, refreshMetricsInterval time.Duration,
+	targetPodHeader string,
+) *grpc.Server {
+	extSvr := grpc.NewServer()
+	go func() {
+		lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
+		if err != nil {
+			errChan <- fmt.Errorf("ext-proc server failed to listen: %w", err)
+			// Serve must not be called with a nil listener.
+			return
+		}
+		klog.Infof("Ext-proc server listening on port: %d", port)
+
+		// Initialize backend provider
+		pp := backend.NewProvider(&vllm.PodMetricsClientImpl{}, datastore)
+		if err := pp.Init(refreshPodsInterval, refreshMetricsInterval); err != nil {
+			errChan <- fmt.Errorf("failed to initialize backend provider: %w", err)
+			// Do not register handlers backed by an uninitialized provider.
+			return
+		}
+
+		// Register ext_proc handlers
+		extProcPb.RegisterExternalProcessorServer(
+			extSvr,
+			handlers.NewServer(pp, scheduling.NewScheduler(pp), targetPodHeader, datastore),
+		)
+
+		// Blocking and will return when shutdown is complete.
+		if serveErr := extSvr.Serve(lis); serveErr != nil && serveErr != grpc.ErrServerStopped {
+			errChan <- fmt.Errorf("ext-proc server failed: %w", serveErr)
+		}
+	}()
+	return extSvr
+}
+
+// validateFlags returns an error naming the first required flag (poolName, then serviceName) that is empty.
+func validateFlags() error {
+	if *poolName == "" {
+		return fmt.Errorf("required %q flag not set", "poolName")
+	}
+
+	if *serviceName == "" {
+		return fmt.Errorf("required %q flag not set", "serviceName")
 	}
 
+	return nil
 }
diff --git a/pkg/manifests/ext_proc.yaml b/pkg/manifests/ext_proc.yaml
index baa04d60..5fbd86a9 100644
--- a/pkg/manifests/ext_proc.yaml
+++ b/pkg/manifests/ext_proc.yaml
@@ -28,7 +28,6 @@ roleRef:
   kind: ClusterRole
   name: pod-read
 ---
-
 apiVersion: apps/v1
 kind: Deployment
 metadata:
@@ -57,9 +56,25 @@ spec:
         - "3"
         - -serviceName
         - "vllm-llama2-7b-pool"
+        - -grpcPort
+        - "9002"
+        - -grpcHealthPort
+        - "9003"
         ports:
         - containerPort: 9002
-
+        - containerPort: 9003
+        livenessProbe:
+          grpc:
+            port: 9003
+            service: inference-extension
+          initialDelaySeconds: 5
+          periodSeconds: 10
+        readinessProbe:
+          grpc:
+            port: 9003
+            service: inference-extension
+          initialDelaySeconds: 5
+          periodSeconds: 10
       - name: curl
         image: curlimages/curl
         command: ["sleep", "3600"]