This repository has been archived by the owner on Sep 19, 2022. It is now read-only.

Merge pull request #151 from johnugeorge/crdchanges
Implement ActiveDeadlineSeconds and BackoffLimit
richardsliu authored Mar 27, 2019
2 parents 261dd72 + 511af4c commit e8d4d04
Showing 11 changed files with 538 additions and 23 deletions.
5 changes: 3 additions & 2 deletions Gopkg.lock

(Generated file; diff not shown.)

10 changes: 10 additions & 0 deletions pkg/apis/pytorch/v1beta2/types.go
@@ -42,6 +42,16 @@ type PyTorchJob struct {

// PyTorchJobSpec is a desired state description of the PyTorchJob.
type PyTorchJobSpec struct {
// Specifies the duration in seconds relative to the startTime that the job may be active
// before the system tries to terminate it; value must be positive integer.
// This method applies only to pods with restartPolicy == OnFailure or Always.
// +optional
ActiveDeadlineSeconds *int64 `json:"activeDeadlineSeconds,omitempty"`

// Optional number of retries before marking this job failed.
// +optional
BackoffLimit *int32 `json:"backoffLimit,omitempty"`

// CleanPodPolicy defines the policy to kill pods after PyTorchJob is
// succeeded.
// Default to Running.
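For reference only (not part of this commit): a minimal sketch of setting the two new optional fields when constructing a spec in Go, assuming the v1beta2 package above is importable from its usual path.

package main

import (
	"fmt"

	v1beta2 "github.com/kubeflow/pytorch-operator/pkg/apis/pytorch/v1beta2"
)

func main() {
	ads := int64(600)   // give up 600s after the job's StartTime
	backoff := int32(3) // tolerate at most 3 retries before failing the job

	spec := v1beta2.PyTorchJobSpec{
		ActiveDeadlineSeconds: &ads,
		BackoffLimit:          &backoff,
	}
	fmt.Printf("deadline=%ds, backoffLimit=%d\n", *spec.ActiveDeadlineSeconds, *spec.BackoffLimit)
}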
10 changes: 10 additions & 0 deletions pkg/apis/pytorch/v1beta2/zz_generated.deepcopy.go

(Generated file; diff not shown.)

34 changes: 34 additions & 0 deletions pkg/common/util/v1beta2/testutil/job.go
@@ -17,6 +17,7 @@ package testutil
import (
"time"

"github.com/golang/protobuf/proto"
"k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

@@ -50,9 +51,42 @@ func NewPyTorchJobWithCleanupJobDelay(master, worker int, ttl *int32) *v1beta2.P
return job
}

func NewPyTorchJobWithActiveDeadlineSeconds(master, worker int, ads *int64) *v1beta2.PyTorchJob {
if master == 1 {
job := NewPyTorchJobWithMaster(worker)
job.Spec.ActiveDeadlineSeconds = ads
policy := common.CleanPodPolicyAll
job.Spec.CleanPodPolicy = &policy
return job
}
job := NewPyTorchJob(worker)
job.Spec.ActiveDeadlineSeconds = ads
policy := common.CleanPodPolicyAll
job.Spec.CleanPodPolicy = &policy
return job
}

func NewPyTorchJobWithBackoffLimit(master, worker int, backoffLimit *int32) *v1beta2.PyTorchJob {
if master == 1 {
job := NewPyTorchJobWithMaster(worker)
job.Spec.BackoffLimit = backoffLimit
job.Spec.PyTorchReplicaSpecs["Worker"].RestartPolicy = "OnFailure"
policy := common.CleanPodPolicyAll
job.Spec.CleanPodPolicy = &policy
return job
}
job := NewPyTorchJob(worker)
job.Spec.BackoffLimit = backoffLimit
job.Spec.PyTorchReplicaSpecs["Worker"].RestartPolicy = "OnFailure"
policy := common.CleanPodPolicyAll
job.Spec.CleanPodPolicy = &policy
return job
}

func NewPyTorchJobWithMaster(worker int) *v1beta2.PyTorchJob {
job := NewPyTorchJob(worker)
job.Spec.PyTorchReplicaSpecs[v1beta2.PyTorchReplicaTypeMaster] = &common.ReplicaSpec{
Replicas: proto.Int32(1),
Template: NewPyTorchReplicaSpecTemplate(),
}
return job
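For reference only (not part of this commit): a hypothetical test showing how the two new constructors above might be exercised. The helper names and signatures come from this diff; everything else is illustrative.

package testutil_test

import (
	"testing"

	"github.com/kubeflow/pytorch-operator/pkg/common/util/v1beta2/testutil"
)

func TestJobConstructorsWithLimits(t *testing.T) {
	ads := int64(30)
	withDeadline := testutil.NewPyTorchJobWithActiveDeadlineSeconds(1, 2, &ads)
	if *withDeadline.Spec.ActiveDeadlineSeconds != 30 {
		t.Fatalf("unexpected ActiveDeadlineSeconds: %d", *withDeadline.Spec.ActiveDeadlineSeconds)
	}

	backoff := int32(2)
	withBackoff := testutil.NewPyTorchJobWithBackoffLimit(1, 2, &backoff)
	if *withBackoff.Spec.BackoffLimit != 2 {
		t.Fatalf("unexpected BackoffLimit: %d", *withBackoff.Spec.BackoffLimit)
	}
}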
7 changes: 5 additions & 2 deletions pkg/common/util/v1beta2/testutil/pod.go
@@ -64,15 +64,18 @@ func NewPodList(count int32, status v1.PodPhase, job *v1beta2.PyTorchJob, typ st
return pods
}

func SetPodsStatuses(podIndexer cache.Indexer, job *v1beta2.PyTorchJob, typ string, pendingPods, activePods, succeededPods, failedPods int32, t *testing.T) {
func SetPodsStatuses(podIndexer cache.Indexer, job *v1beta2.PyTorchJob, typ string, pendingPods, activePods, succeededPods, failedPods int32, restartCounts []int32, t *testing.T) {
var index int32
for _, pod := range NewPodList(pendingPods, v1.PodPending, job, typ, index, t) {
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", job.Name, err)
}
}
index += pendingPods
for _, pod := range NewPodList(activePods, v1.PodRunning, job, typ, index, t) {
for i, pod := range NewPodList(activePods, v1.PodRunning, job, typ, index, t) {
if restartCounts != nil {
pod.Status.ContainerStatuses = []v1.ContainerStatus{{RestartCount: restartCounts[i]}}
}
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", job.Name, err)
}
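For reference only (not part of this commit): a hypothetical test sketch showing the new restartCounts argument in use. It assumes the exported testutil helpers above and the standard client-go cache package.

package testutil_test

import (
	"testing"

	"k8s.io/client-go/tools/cache"

	"github.com/kubeflow/pytorch-operator/pkg/common/util/v1beta2/testutil"
)

func TestSetPodsStatusesWithRestartCounts(t *testing.T) {
	job := testutil.NewPyTorchJobWithMaster(2)
	podIndexer := cache.NewIndexer(cache.MetaNamespaceKeyFunc, cache.Indexers{})

	// Two running worker pods with restart counts 1 and 3; a backoff-limit
	// check that sums restarts would see 4 for this replica type.
	testutil.SetPodsStatuses(podIndexer, job, testutil.LabelWorker,
		0, 2, 0, 0, []int32{1, 3}, t)

	if got := len(podIndexer.List()); got != 2 {
		t.Fatalf("expected 2 pods in the indexer, got %d", got)
	}
}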
124 changes: 113 additions & 11 deletions pkg/controller.v1beta2/pytorch/controller.go
@@ -17,6 +17,7 @@ package pytorch

import (
"fmt"
"strings"
"time"

kubebatchclient "github.com/kubernetes-sigs/kube-batch/pkg/client/clientset/versioned"
@@ -38,8 +39,10 @@ import (
jobinformers "github.com/kubeflow/pytorch-operator/pkg/client/informers/externalversions"
jobinformersv1beta2 "github.com/kubeflow/pytorch-operator/pkg/client/informers/externalversions/pytorch/v1beta2"
joblisters "github.com/kubeflow/pytorch-operator/pkg/client/listers/pytorch/v1beta2"
common "github.com/kubeflow/tf-operator/pkg/apis/common/v1beta2"
"github.com/kubeflow/tf-operator/pkg/common/jobcontroller"
pylogger "github.com/kubeflow/tf-operator/pkg/logger"
"github.com/kubeflow/tf-operator/pkg/util/k8sutil"
)

const (
@@ -326,18 +329,15 @@ func (pc *PyTorchController) syncPyTorchJob(key string) (bool, error) {
return true, err
}

func getTotalReplicas(obj metav1.Object) int32 {
job := obj.(*v1beta2.PyTorchJob)
jobReplicas := int32(0)
for _, r := range job.Spec.PyTorchReplicaSpecs {
jobReplicas += *r.Replicas
}
return jobReplicas
}

// reconcilePyTorchJobs checks and updates replicas for each given PyTorchReplicaSpec.
// It will requeue the job in case of an error while creating/deleting pods/services.
func (pc *PyTorchController) reconcilePyTorchJobs(job *v1beta2.PyTorchJob) error {
jobKey, err := KeyFunc(job)
if err != nil {
utilruntime.HandleError(fmt.Errorf("couldn't get key for pytorch job object %#v: %v", job, err))
return err
}

logger := pylogger.LoggerForJob(job)
logger.Infof("Reconcile PyTorchJobs %s", job.Name)

@@ -355,8 +355,46 @@ func (pc *PyTorchController) reconcilePyTorchJobs(job *v1beta2.PyTorchJob) error
return err
}

// retrieve the previous number of retries
previousRetry := pc.WorkQueue.NumRequeues(jobKey)

activePods := k8sutil.FilterActivePods(pods)
active := int32(len(activePods))
failed := int32(k8sutil.FilterPods(pods, v1.PodFailed))
totalReplicas := getTotalReplicas(job)
prevReplicasFailedNum := getTotalFailedReplicas(job)

var failureMessage string
jobExceedsLimit := false
exceedsBackoffLimit := false
pastBackoffLimit := false

if job.Spec.BackoffLimit != nil {
jobHasNewFailure := failed > prevReplicasFailedNum
// New failures happen when the status does not yet reflect the failures and the
// number of active pods differs from the total replicas; otherwise the previous
// controller loop failed to update the status, so a failure we pick up here is not a new one.
exceedsBackoffLimit = jobHasNewFailure && (active != totalReplicas) &&
(int32(previousRetry)+1 > *job.Spec.BackoffLimit)

pastBackoffLimit, err = pc.pastBackoffLimit(job, pods)
if err != nil {
return err
}
}

if exceedsBackoffLimit || pastBackoffLimit {
// The job exceeds its limit if the summed pod restart count passes the backoff limit
// (counted only for restartPolicy OnFailure/Always) OR if the number of failed replicas increased since the last sync.
jobExceedsLimit = true
failureMessage = fmt.Sprintf("PyTorchJob %s has failed because it has reached the specified backoff limit", job.Name)
} else if pc.pastActiveDeadline(job) {
failureMessage = fmt.Sprintf("PyTorchJob %s has failed because it was active longer than specified deadline", job.Name)
jobExceedsLimit = true
}

// If the PyTorchJob is terminated, delete all pods and services.
if isSucceeded(job.Status) || isFailed(job.Status) {
if isSucceeded(job.Status) || isFailed(job.Status) || jobExceedsLimit {
if err := pc.deletePodsAndServices(job, pods); err != nil {
return err
}
@@ -375,7 +413,18 @@ func (pc *PyTorchController) reconcilePyTorchJobs(job *v1beta2.PyTorchJob) error

}
}

if jobExceedsLimit {
pc.Recorder.Event(job, v1.EventTypeNormal, pytorchJobFailedReason, failureMessage)
if job.Status.CompletionTime == nil {
now := metav1.Now()
job.Status.CompletionTime = &now
}
err := updatePyTorchJobConditions(job, common.JobFailed, pytorchJobFailedReason, failureMessage)
if err != nil {
logger.Infof("Append pytorchjob condition error: %v", err)
return err
}
}
// At this point the pods may have been deleted, so if the job succeeded, we need to manually set the replica status.
// If any replicas are still Active, set their status to succeeded.
if isSucceeded(job.Status) {
@@ -434,6 +483,59 @@ func (pc *PyTorchController) satisfiedExpectations(job *v1beta2.PyTorchJob) bool
return satisfied
}

// pastBackoffLimit checks whether the sum of container restart counts exceeds BackoffLimit.
// This method applies only to pods with restartPolicy == OnFailure or Always.
func (pc *PyTorchController) pastBackoffLimit(job *v1beta2.PyTorchJob, pods []*v1.Pod) (bool, error) {
if job.Spec.BackoffLimit == nil {
return false, nil
}
logger := pylogger.LoggerForJob(job)
result := int32(0)
for rtype, spec := range job.Spec.PyTorchReplicaSpecs {
if spec.RestartPolicy != common.RestartPolicyOnFailure && spec.RestartPolicy != common.RestartPolicyAlways {
logger.Warnf("The restart policy of replica %v of the job %v is not OnFailure or Always. Not counted in backoff limit.", rtype, job.Name)
continue
}
// Convert PyTorchReplicaType to lower string.
rt := strings.ToLower(string(rtype))
pods, err := pc.FilterPodsForReplicaType(pods, rt)
if err != nil {
return false, err
}
for i := range pods {
po := pods[i]
if po.Status.Phase != v1.PodRunning {
continue
}
for j := range po.Status.InitContainerStatuses {
stat := po.Status.InitContainerStatuses[j]
result += stat.RestartCount
}
for j := range po.Status.ContainerStatuses {
stat := po.Status.ContainerStatuses[j]
result += stat.RestartCount
}
}
}

if *job.Spec.BackoffLimit == 0 {
return result > 0, nil
}
return result >= *job.Spec.BackoffLimit, nil
}

// pastActiveDeadline checks if job has ActiveDeadlineSeconds field set and if it is exceeded.
func (pc *PyTorchController) pastActiveDeadline(job *v1beta2.PyTorchJob) bool {
if job.Spec.ActiveDeadlineSeconds == nil || job.Status.StartTime == nil {
return false
}
now := metav1.Now()
start := job.Status.StartTime.Time
duration := now.Time.Sub(start)
allowedDuration := time.Duration(*job.Spec.ActiveDeadlineSeconds) * time.Second
return duration >= allowedDuration
}

func (pc *PyTorchController) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
return pc.getPyTorchJobFromName(namespace, name)
}
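For reference only (not part of this commit): the deadline check above is plain wall-clock arithmetic against Status.StartTime. A standalone sketch of the same comparison, using a hypothetical helper rather than the controller method.

package main

import (
	"fmt"
	"time"
)

// pastDeadline mirrors the controller's check: the job is past its deadline
// once (now - startTime) >= activeDeadlineSeconds.
func pastDeadline(startTime time.Time, activeDeadlineSeconds int64) bool {
	allowed := time.Duration(activeDeadlineSeconds) * time.Second
	return time.Since(startTime) >= allowed
}

func main() {
	start := time.Now().Add(-10 * time.Minute)
	fmt.Println(pastDeadline(start, 300)) // true: 10 minutes elapsed, only 300s allowed
	fmt.Println(pastDeadline(start, 900)) // false: 900s not yet reached
}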
4 changes: 2 additions & 2 deletions pkg/controller.v1beta2/pytorch/controller_test.go
@@ -272,8 +272,8 @@ func TestNormalPath(t *testing.T) {
}

podIndexer := kubeInformerFactory.Core().V1().Pods().Informer().GetIndexer()
testutil.SetPodsStatuses(podIndexer, job, testutil.LabelWorker, tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods, t)
testutil.SetPodsStatuses(podIndexer, job, testutil.LabelMaster, tc.pendingMasterPods, tc.activeMasterPods, tc.succeededMasterPods, tc.failedMasterPods, t)
testutil.SetPodsStatuses(podIndexer, job, testutil.LabelWorker, tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods, nil, t)
testutil.SetPodsStatuses(podIndexer, job, testutil.LabelMaster, tc.pendingMasterPods, tc.activeMasterPods, tc.succeededMasterPods, tc.failedMasterPods, nil, t)

serviceIndexer := kubeInformerFactory.Core().V1().Services().Informer().GetIndexer()
testutil.SetServices(serviceIndexer, job, testutil.LabelWorker, tc.activeWorkerServices, t)
45 changes: 45 additions & 0 deletions pkg/controller.v1beta2/pytorch/job.go
@@ -105,8 +105,37 @@ func (pc *PyTorchController) updatePyTorchJob(old, cur interface{}) {
if err != nil {
return
}
curPyTorchJob, err := jobFromUnstructured(cur)
if err != nil {
return
}

// never return error
key, err := KeyFunc(curPyTorchJob)
if err != nil {
return
}

log.Infof("Updating pytorchjob: %s", oldPyTorchJob.Name)
pc.enqueuePyTorchJob(cur)

// check whether we need to schedule a new resync for ActiveDeadlineSeconds
if curPyTorchJob.Status.StartTime != nil {
curPyTorchJobADS := curPyTorchJob.Spec.ActiveDeadlineSeconds
if curPyTorchJobADS == nil {
return
}
oldPyTorchJobADS := oldPyTorchJob.Spec.ActiveDeadlineSeconds
if oldPyTorchJobADS == nil || *oldPyTorchJobADS != *curPyTorchJobADS {
now := metav1.Now()
start := curPyTorchJob.Status.StartTime.Time
passed := now.Time.Sub(start)
total := time.Duration(*curPyTorchJobADS) * time.Second
// AddAfter will handle total < passed
pc.WorkQueue.AddAfter(key, total-passed)
log.Infof("job ActiveDeadlineSeconds updated, will rsync after %d seconds", total-passed)
}
}
}

func (pc *PyTorchController) deletePodsAndServices(job *v1beta2.PyTorchJob, pods []*v1.Pod) error {
@@ -160,3 +189,19 @@ func (pc *PyTorchController) cleanupPyTorchJob(job *v1beta2.PyTorchJob) error {
func (pc *PyTorchController) deletePyTorchJob(job *v1beta2.PyTorchJob) error {
return pc.jobClientSet.KubeflowV1beta2().PyTorchJobs(job.Namespace).Delete(job.Name, &metav1.DeleteOptions{})
}

func getTotalReplicas(job *v1beta2.PyTorchJob) int32 {
jobReplicas := int32(0)
for _, r := range job.Spec.PyTorchReplicaSpecs {
jobReplicas += *r.Replicas
}
return jobReplicas
}

func getTotalFailedReplicas(job *v1beta2.PyTorchJob) int32 {
totalFailedReplicas := int32(0)
for rtype := range job.Status.ReplicaStatuses {
totalFailedReplicas += job.Status.ReplicaStatuses[rtype].Failed
}
return totalFailedReplicas
}
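
For reference only (not part of this commit): the requeue delay computed in updatePyTorchJob above is just the remaining deadline budget, total - passed. A standalone sketch with a hypothetical helper.

package main

import (
	"fmt"
	"time"
)

// remainingDeadline returns how long the work queue should wait before
// resyncing a job whose ActiveDeadlineSeconds was just (re)set.
// A non-positive result means the deadline has already passed.
func remainingDeadline(startTime time.Time, activeDeadlineSeconds int64) time.Duration {
	total := time.Duration(activeDeadlineSeconds) * time.Second
	passed := time.Since(startTime)
	return total - passed
}

func main() {
	start := time.Now().Add(-2 * time.Minute)
	fmt.Println(remainingDeadline(start, 600)) // roughly 8m0s left
	fmt.Println(remainingDeadline(start, 60))  // negative: already past the deadline
}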
