Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add group flags #1778

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion command.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ type Command struct {
iflags *flag.FlagSet
// parentsPflags is all persistent flags of cmd's parents.
parentsPflags *flag.FlagSet
// lnamedFlagSets contains local named flags.
lnamedFlagSets *NamedFlagSets
// globNormFunc is the global normalization function
// that we can use on every pflag set and children commands
globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName
Expand Down Expand Up @@ -514,7 +516,8 @@ Available Commands:{{range .Commands}}{{if (or .IsAvailableCommand (eq .Name "he
{{rpad .Name .NamePadding }} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableLocalFlags}}

Flags:
{{.LocalFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasAvailableInheritedFlags}}
{{.LocalNonNamedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasAvailableNamedFlags}}
{{.NamedFlagSets.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasAvailableInheritedFlags}}

Global Flags:
{{.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasHelpSubCommands}}
Expand Down Expand Up @@ -578,6 +581,7 @@ func stripFlags(args []string, c *Command) []string {
return args
}
c.mergePersistentFlags()
c.mergeNamedFlags()

commands := []string{}
flags := c.Flags()
Expand Down Expand Up @@ -1046,6 +1050,7 @@ func (c *Command) validateRequiredFlags() error {
// If c already has help flag, it will do nothing.
func (c *Command) InitDefaultHelpFlag() {
c.mergePersistentFlags()
c.mergeNamedFlags()
if c.Flags().Lookup("help") == nil {
usage := "help for "
if c.Name() == "" {
Expand All @@ -1067,6 +1072,7 @@ func (c *Command) InitDefaultVersionFlag() {
}

c.mergePersistentFlags()
c.mergeNamedFlags()
if c.Flags().Lookup("version") == nil {
usage := "version for "
if c.Name() == "" {
Expand Down Expand Up @@ -1475,6 +1481,31 @@ func (c *Command) Flags() *flag.FlagSet {
return c.flags
}

// NamedFlagSets returns all the named FlagSet that applies to this command.
func (c *Command) NamedFlagSets() *NamedFlagSets {
if c.lnamedFlagSets == nil {
c.lnamedFlagSets = NewNamedFlagSets(c.Name(), flag.ContinueOnError)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the flat.ContinueOnError be hardcoded?

}
return c.lnamedFlagSets
}

// NamedFlags returns the specific named FlagSet that applies to this command.
func (c *Command) NamedFlags(name string) *flag.FlagSet {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NamedFlags doesn't support persistent flags at the moment.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have any idea of what should be done to add this support?

nfs := c.NamedFlagSets()
flags, ok := nfs.FlagSet(name)
if !ok {
if c.flagErrorBuf == nil {
c.flagErrorBuf = new(bytes.Buffer)
}
flags.SetOutput(c.flagErrorBuf)
if c.globNormFunc != nil {
flags.SetNormalizeFunc(c.globNormFunc)
}
}

return flags
}

// LocalNonPersistentFlags are flags specific to this command which will NOT persist to subcommands.
func (c *Command) LocalNonPersistentFlags() *flag.FlagSet {
persistentFlags := c.PersistentFlags()
Expand All @@ -1488,9 +1519,23 @@ func (c *Command) LocalNonPersistentFlags() *flag.FlagSet {
return out
}

// LocalNonNamedFlags are flags specific to this command which are NOT named.
func (c *Command) LocalNonNamedFlags() *flag.FlagSet {
namedFlags := c.NamedFlagSets().Flatten()

out := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.LocalFlags().VisitAll(func(f *flag.Flag) {
if namedFlags.Lookup(f.Name) == nil {
out.AddFlag(f)
}
})
return out
}

// LocalFlags returns the local FlagSet specifically set in the current command.
func (c *Command) LocalFlags() *flag.FlagSet {
c.mergePersistentFlags()
c.mergeNamedFlags()

if c.lflags == nil {
c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
Expand All @@ -1517,6 +1562,7 @@ func (c *Command) LocalFlags() *flag.FlagSet {
// InheritedFlags returns all flags which were inherited from parent commands.
func (c *Command) InheritedFlags() *flag.FlagSet {
c.mergePersistentFlags()
c.mergeNamedFlags()

if c.iflags == nil {
c.iflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
Expand Down Expand Up @@ -1568,6 +1614,7 @@ func (c *Command) ResetFlags() {
c.lflags = nil
c.iflags = nil
c.parentsPflags = nil
c.lnamedFlagSets = nil
}

// HasFlags checks if the command contains any flags (local plus persistent from the entire structure).
Expand Down Expand Up @@ -1596,6 +1643,11 @@ func (c *Command) HasAvailableFlags() bool {
return c.Flags().HasAvailableFlags()
}

// HasAvailableNamedFlags checks if the command contains any named flags which are not hidden or deprecated.
func (c *Command) HasAvailableNamedFlags() bool {
return c.NamedFlagSets().Flatten().HasAvailableFlags()
}

// HasAvailablePersistentFlags checks if the command contains persistent flags which are not hidden or deprecated.
func (c *Command) HasAvailablePersistentFlags() bool {
return c.PersistentFlags().HasAvailableFlags()
Expand Down Expand Up @@ -1648,6 +1700,7 @@ func (c *Command) ParseFlags(args []string) error {
}
beforeErrorBufLen := c.flagErrorBuf.Len()
c.mergePersistentFlags()
c.mergeNamedFlags()

// do it here after merging all flags and just before parse
c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)
Expand All @@ -1674,6 +1727,11 @@ func (c *Command) mergePersistentFlags() {
c.Flags().AddFlagSet(c.parentsPflags)
}

// mergeNamedFlags merges c.NamedFlagSets() to c.Flags()
func (c *Command) mergeNamedFlags() {
c.Flags().AddFlagSet(c.NamedFlagSets().Flatten())
}

// updateParentsPflags updates c.parentsPflags by adding
// new persistent flags of all parents.
// If c.parentsPflags == nil, it makes new.
Expand Down
89 changes: 89 additions & 0 deletions named_flag_sets.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package cobra
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this file be in spf13/pflag?


import (
"bytes"
"fmt"
"strings"

"github.com/spf13/pflag"
)

// NamedFlagSets stores named flag sets in the order of calling FlagSet.
type NamedFlagSets struct {
name string
errorHandling pflag.ErrorHandling

// order is an ordered list of flag set names.
order []string
// flagSets stores the flag sets by name.
flagSets map[string]*pflag.FlagSet
}

func NewNamedFlagSets(name string, errorHandling pflag.ErrorHandling) *NamedFlagSets {
return &NamedFlagSets{
name: name,
errorHandling: errorHandling,
}
}

// FlagSet returns the flag set with the given name and adds it to the
// ordered name list if it is not in there yet.
func (nfs *NamedFlagSets) FlagSet(name string) (*pflag.FlagSet, bool) {
if nfs.flagSets == nil {
nfs.flagSets = map[string]*pflag.FlagSet{}
}
Comment on lines +32 to +34

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this operation be done in NewNamedFlagSets()?

var ok bool
if _, ok = nfs.flagSets[name]; !ok {
flagSet := pflag.NewFlagSet(name, nfs.errorHandling)
nfs.flagSets[name] = flagSet
nfs.order = append(nfs.order, name)
}
return nfs.flagSets[name], ok
}

// Flatten returns a single flag set containing all the flag sets
// in the NamedFlagSet
func (nfs *NamedFlagSets) Flatten() *pflag.FlagSet {
out := pflag.NewFlagSet(nfs.name, nfs.errorHandling)
for _, fs := range nfs.flagSets {
out.AddFlagSet(fs)
}
return out
}

// FlagUsages returns a string containing the usage information for all flags in
// the FlagSet
func (nfs *NamedFlagSets) FlagUsages() string {
return nfs.FlagUsagesWrapped(0)
}

func (nfs *NamedFlagSets) FlagUsagesWrapped(cols int) string {
var buf bytes.Buffer

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to use a string builder?

for _, name := range nfs.order {
fs := nfs.flagSets[name]
if !fs.HasFlags() {
continue
}

wideFS := pflag.NewFlagSet("", pflag.ExitOnError)
wideFS.AddFlagSet(fs)

var zzz string
if cols > 24 {
zzz = strings.Repeat("z", cols-24)
wideFS.Int(zzz, 0, strings.Repeat("z", cols-24))
}

s := fmt.Sprintf("\n%s Flags:\n%s", strings.ToUpper(name[:1])+name[1:], wideFS.FlagUsagesWrapped(cols))

if cols > 24 {
i := strings.Index(s, zzz)
lines := strings.Split(s[:i], "\n")
fmt.Fprint(&buf, strings.Join(lines[:len(lines)-1], "\n"))
fmt.Fprintln(&buf)
Comment on lines +79 to +83

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this code really needed? And why 24?

} else {
fmt.Fprint(&buf, s)
}
}
return buf.String()
}