diff --git a/tools/identify_license/backend/backend.go b/tools/identify_license/backend/backend.go
index b1976bf..a9e46ba 100644
--- a/tools/identify_license/backend/backend.go
+++ b/tools/identify_license/backend/backend.go
@@ -65,37 +65,64 @@ func (b *ClassifierBackend) Close() {
 
 // ClassifyLicenses runs the license classifier over the given file.
 func (b *ClassifierBackend) ClassifyLicenses(filenames []string, headers bool) (errors []error) {
-	// Create a pool from which tasks can later be started. We use a pool because the OS limits
-	// the number of files that can be open at any one time.
-	const numTasks = 1000
-	task := make(chan bool, numTasks)
-	for i := 0; i < numTasks; i++ {
-		task <- true
-	}
+	return b.ClassifyLicensesWithContext(context.Background(), filenames, headers)
+}
 
+// ClassifyLicensesWithContext runs the license classifier over the given
+// files. The context is checked before each file is classified: once it is
+// done, every remaining file is reported as an error instead of being
+// classified. A classification that is already in progress is not interrupted.
+func (b *ClassifierBackend) ClassifyLicensesWithContext(ctx context.Context, filenames []string, headers bool) (errors []error) {
+	// Queue every filename up front and close the channel so that the workers'
+	// range loops end once the queue is drained.
+	files := make(chan string, len(filenames))
+	for _, f := range filenames {
+		files <- f
+	}
+	close(files)
+
+	// Each file yields at most one error, so this buffer guarantees that a
+	// worker never blocks on send.
 	errs := make(chan error, len(filenames))
 
 	var wg sync.WaitGroup
-	analyze := func(filename string) {
-		defer func() {
-			task <- true
-			wg.Done()
+
+	// Create a pool from which tasks can later be started. We use a pool because the OS limits
+	// the number of files that can be open at any one time.
+	const numTasks = 1000
+
+	// Never start more workers than there are files to classify.
+	numWorkers := numTasks
+	if len(filenames) < numWorkers {
+		numWorkers = len(filenames)
+	}
+	wg.Add(numWorkers)
+
+	for i := 0; i < numWorkers; i++ {
+		go func() {
+			// Ensure that however this function terminates, the wait group
+			// is unblocked
+			defer wg.Done()
+
+			// Ranging over the closed channel ends when the queue is empty,
+			// so an empty filename is not mistaken for the end of the queue.
+			for filename := range files {
+				// If the context is done, record that the file was not
+				// classified due to the context's termination.
+				if err := ctx.Err(); err != nil {
+					errs <- fmt.Errorf("file %s not classified due to context completion: %v", filename, err)
+					continue
+				}
+
+				if err := b.classifyLicense(filename, headers); err != nil {
+					errs <- err
+				}
+			}
 		}()
-		if err := b.classifyLicense(filename, headers); err != nil {
-			errs <- err
-		}
 	}
 
-	for _, filename := range filenames {
-		wg.Add(1)
-		<-task
-		go analyze(filename)
-	}
-	go func() {
-		wg.Wait()
-		close(task)
-		close(errs)
-	}()
+	wg.Wait()
+	close(errs)
 
 	for err := range errs {
 		errors = append(errors, err)
@@ -103,23 +130,6 @@ func (b *ClassifierBackend) ClassifyLicenses(filenames []string, headers bool) (
 	return errors
 }
 
-// ClassifyLicensesWithContext runs the license classifier over the given file; ensure that it will respect the timeout in the provided context.
-func (b *ClassifierBackend) ClassifyLicensesWithContext(ctx context.Context, filenames []string, headers bool) (errors []error) {
-	done := make(chan bool)
-	go func() {
-		errors = b.ClassifyLicenses(filenames, headers)
-		done <- true
-	}()
-	select {
-	case <-ctx.Done():
-		err := ctx.Err()
-		errors = append(errors, err)
-		return errors
-	case <-done:
-		return errors
-	}
-}
-
 // classifyLicense is called by a Go-function to perform the actual
 // classification of a license.
 func (b *ClassifierBackend) classifyLicense(filename string, headers bool) error {