diff --git a/cmd/cmd.go b/cmd/cmd.go index 36d4af08..01e68ba3 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1060,6 +1060,32 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { return nil } +func completionFlagHandler(cmd *cobra.Command, _ []string, shells []string, completionFlag string) { + shl, err := cmd.Flags().GetString(completionFlag) + if err != nil { + fmt.Println("required one argument ") + return + } + if !slices.Contains(shells, shl) { + fmt.Println("argument was not one of " + strings.Join(shells, ", ")) + return + } + switch shl { + case shells[0]: + cmd.Root().GenBashCompletion(os.Stdout) + case shells[1]: + cmd.Root().GenZshCompletion(os.Stdout) + case shells[2]: + cmd.Root().GenFishCompletion(os.Stdout, true) + case shells[3]: + cmd.Root().GenPowerShellCompletionWithDesc(os.Stdout) + default: + fmt.Println("not a recognized shell") + os.Exit(1) + } + os.Exit(0) +} + func versionHandler(cmd *cobra.Command, _ []string) { client, err := api.ClientFromEnvironment() if err != nil { @@ -1103,6 +1129,9 @@ func NewCLI() *cobra.Command { console.ConsoleFromFile(os.Stdin) //nolint:errcheck } + var shells = []string{"bash", "zsh", "fish", "powershell"} + completionFlag := "shell-completion" + rootCmd := &cobra.Command{ Use: "ollama", Short: "Large language model runner", @@ -1117,11 +1146,20 @@ func NewCLI() *cobra.Command { return } + if compShell, _ := cmd.Flags().GetString(completionFlag); compShell != "" { + completionFlagHandler(cmd, args, shells, completionFlag) + return + } + cmd.Print(cmd.UsageString()) }, } rootCmd.Flags().BoolP("version", "v", false, "Show version information") + rootCmd.Flags().StringP(completionFlag, "", "", " Generate shell completions for "+strings.Join(shells, ", ")) + rootCmd.RegisterFlagCompletionFunc(completionFlag, func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return shells, cobra.ShellCompDirectiveDefault + }) createCmd := &cobra.Command{ Use: "create MODEL",