package cobra

import (
	"bytes"
	"fmt"
	"io"
	"os"
	"strings"

	flag "github.com/spf13/pflag"
)

// GenZshCompletionFile generates a zsh completion script for c and writes it
// to filename, creating or truncating the file.
//
// The error from closing the file is reported when generation itself
// succeeded, so a failed final flush is not silently treated as success.
func (c *Command) GenZshCompletionFile(filename string) (err error) {
	outFile, err := os.Create(filename)
	if err != nil {
		return err
	}
	defer func() {
		// Only surface the Close error if nothing earlier failed; an earlier
		// write error is more informative.
		if cerr := outFile.Close(); err == nil {
			err = cerr
		}
	}()

	return c.GenZshCompletion(outFile)
}

// argName returns the name of the zsh array variable used to hold flag specs.
// The name is always derived from the root of cmd's command tree (cmd itself
// is only a starting point for the walk upward), suffixed with "_cmd_args",
// and every "-" is replaced by "_" because zsh parameter names cannot
// contain dashes.
func argName(cmd *Command) string {
	root := cmd
	for root.HasParent() {
		root = root.Parent()
	}
	return strings.Replace(root.Name()+"_cmd_args", "-", "_", -1)
}

// GenZshCompletion generates a zsh completion script for c and writes it to w.
//
// The script is assembled in an in-memory buffer and then copied to w in a
// single WriteTo call; the only error returned is the one from that copy.
func (c *Command) GenZshCompletion(w io.Writer) error {
	var script bytes.Buffer

	writeHeader(&script, c)
	depth := maxDepth(c)
	name := c.Name()

	fmt.Fprintf(&script, "_%s() {\n", name)
	writeLevelMapping(&script, depth, c)
	writeLevelCases(&script, depth, c)
	// Close the completion function and invoke it with the original arguments.
	fmt.Fprintf(&script, "}\n_%s \"$@\"\n", name)

	_, err := script.WriteTo(w)
	return err
}

// writeHeader writes the "#compdef <name>" line for cmd, followed by a blank
// line, to w. Write errors are ignored.
func writeHeader(w io.Writer, cmd *Command) {
	io.WriteString(w, "#compdef "+cmd.Name()+"\n\n")
}

// maxDepth returns the height of the command tree rooted at c: 0 when c has
// no subcommands, otherwise 1 plus the largest maxDepth of its subcommands.
func maxDepth(c *Command) int {
	deepest := -1 // stays -1 only when there are no subcommands
	for _, sub := range c.Commands() {
		if d := maxDepth(sub); d > deepest {
			deepest = d
		}
	}
	return deepest + 1
}

// writeLevelMapping writes the local declarations and the `_arguments -C`
// call that opens the completion function.
//
// Positional argument i (1..numLevels) is mapped to the state "level<i>".
// The flag-spec array named by argName(root) is then expanded, and any
// further arguments complete as files.
func writeLevelMapping(w io.Writer, numLevels int, root *Command) {
	lines := []string{
		`local context curcontext="$curcontext" state line`,
		`typeset -A opt_args`,
		`_arguments -C  \`,
	}
	for level := 1; level <= numLevels; level++ {
		lines = append(lines, fmt.Sprintf("  '%d: :->level%d' \\", level, level))
	}
	lines = append(lines,
		"  $"+argName(root)+" \\",
		`  '*: :_files'`,
	)

	for _, line := range lines {
		fmt.Fprintln(w, line)
	}
}

// writeLevelCases writes the `case $state in ... esac` statement.
//
// The statement has one branch per level from 1 to maxDepth, produced by
// writeLevel, and a catch-all branch that falls back to file completion.
func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
	fmt.Fprintln(w, "case $state in")
	for level := 1; level <= maxDepth; level++ {
		writeLevel(w, root, level)
	}
	fmt.Fprint(w, "  *)\n"+
		"    _arguments '*: :_files'\n"+
		"  ;;\n"+
		"esac\n")
}

// writeLevel writes the `level<l>)` branch of the state case statement.
//
// The branch switches on $words[l] with one arm per command returned by
// filterByLevel(root, l); writeCommandArgsBlock may emit nothing for a given
// command. A final `*)` arm falls back to file completion.
func writeLevel(w io.Writer, root *Command, l int) {
	fmt.Fprintf(w, "  level%d)\n    case $words[%d] in\n", l, l)
	for _, cmd := range filterByLevel(root, l) {
		writeCommandArgsBlock(w, cmd)
	}
	fmt.Fprint(w, "      *)\n"+
		"        _arguments '*: :_files'\n"+
		"      ;;\n"+
		"    esac\n"+
		"  ;;\n")
}

// writeCommandArgsBlock writes the case arm for command c inside a level's
// `case $words[N] in` statement.
//
// When c has flags, the arm assigns them to the array variable named by
// argName(c). When c has subcommands, it offers them via `_values 'command'`.
// A command with neither flags nor subcommands produces no arm at all.
func writeCommandArgsBlock(w io.Writer, c *Command) {
	subcommands := commandNames(c)
	flagSpecs := commandFlags(c)
	if len(subcommands) == 0 && len(flagSpecs) == 0 {
		return
	}

	fmt.Fprintf(w, "      %s)\n", c.Name())
	if len(flagSpecs) > 0 {
		fmt.Fprintf(w, "        %s=(\n", argName(c))
		for _, spec := range flagSpecs {
			fmt.Fprintf(w, "            %s\n", spec)
		}
		fmt.Fprintln(w, "        )")
	}
	if len(subcommands) > 0 {
		// Each spec becomes its own single-quoted word.
		quoted := strings.Join(subcommands, "' '")
		fmt.Fprintf(w, "        _values 'command' '%s'\n", quoted)
	}
	fmt.Fprintln(w, "      ;;")
}

// filterByLevel returns the commands that sit l-1 levels below c in the
// command tree.
//
// For example, l == 1 yields [c] and l == 2 yields c's direct subcommands.
// Values of l below 1 also yield [c].
func filterByLevel(c *Command, l int) []*Command {
	level := []*Command{c}
	for depth := 1; depth < l; depth++ {
		var children []*Command
		for _, parent := range level {
			if parent.HasSubCommands() {
				children = append(children, parent.Commands()...)
			}
		}
		level = children
	}
	return level
}

// commandNames returns one zsh `_values` spec per direct subcommand of
// command. Each spec is the subcommand name, followed by "[<Short>]" when a
// short description is set.
//
// The caller (writeCommandArgsBlock) wraps each spec in single quotes. The
// description is therefore escaped so an apostrophe does not terminate the
// shell word, and '[' / ']' do not terminate the bracketed description.
func commandNames(command *Command) []string {
	escape := strings.NewReplacer("'", `'\''`, "[", `\[`, "]", `\]`)
	commands := command.Commands()
	ns := make([]string, 0, len(commands))
	for _, c := range commands {
		spec := c.Name()
		if c.Short != "" {
			spec += "[" + escape.Replace(c.Short) + "]"
		}
		ns = append(ns, spec)
	}
	return ns
}

// commandFlags returns one zsh `_arguments` option spec per flag in
// command.Flags().
//
// A flag with a shorthand becomes the brace expansion "{-s,--name}"; one
// without becomes "--name". A non-empty usage string is appended as a
// single-quoted "'[usage]'" description. The usage is escaped so an
// apostrophe cannot break the generated shell syntax and '[' / ']' cannot end
// the description early.
func commandFlags(command *Command) []string {
	escape := strings.NewReplacer("'", `'\''`, "[", `\[`, "]", `\]`)
	var specs []string
	// The parameter is named f so it does not shadow the imported flag package.
	command.Flags().VisitAll(func(f *flag.Flag) {
		var spec string
		if f.Shorthand != "" {
			spec = fmt.Sprintf("{-%s,--%s}", f.Shorthand, f.Name)
		} else {
			spec = "--" + f.Name
		}
		if f.Usage != "" {
			spec += "'[" + escape.Replace(f.Usage) + "]'"
		}
		specs = append(specs, spec)
	})
	return specs
}