package main

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"

	_ "git.sr.ht/~charles/rq/builtins"
	rqio "git.sr.ht/~charles/rq/io"
	"git.sr.ht/~charles/rq/util"
	"github.com/open-policy-agent/opa/v1/ast"
	opabundle "github.com/open-policy-agent/opa/v1/bundle"
	"github.com/open-policy-agent/opa/v1/rego"
	"github.com/open-policy-agent/opa/v1/storage"
	"github.com/open-policy-agent/opa/v1/storage/inmem"
	"github.com/open-policy-agent/opa/v1/topdown/print"
)

// whitespace is the cutset of non-newline whitespace characters (space, tab,
// carriage return) that is trimmed from the pieces of an rq script directive.
const whitespace = " \t\r"

// resolveEscapes makes a best-effort attempt to expand escape sequences in
// the option fields where that is meaningful (--csv-comma and --csv-comment).
// Failures are ignored and the field is left as the user wrote it.
func resolveEscapes(common *Common) {
	// single returns the unescaped form of raw, but only when raw is longer
	// than one byte and unescapes cleanly to exactly one byte. In every
	// other case raw is returned untouched, so that later error messages
	// show the value as typed rather than a mutated version of it.
	single := func(raw string) string {
		if len(raw) <= 1 {
			return raw
		}
		if resolved, err := util.Unescape(raw); err == nil && len(resolved) == 1 {
			return resolved
		}
		return raw
	}

	common.CSVComma = single(common.CSVComma)
	common.CSVComment = single(common.CSVComment)
}

// defaultInputOptions builds the DataSpec holding the default input format
// and CSV options derived from the command line flags in common.
//
// Escape sequences in --csv-comma and --csv-comment are resolved first (this
// mutates common). An error is returned if either flag is still longer than
// one byte afterwards.
//
// The returned spec always carries "csv.comment", "csv.skip_lines",
// "csv.headers" and "csv.infer"; "csv.comma" is only set when a comma was
// supplied.
func defaultInputOptions(common *Common) (*rqio.DataSpec, error) {
	resolveEscapes(common)

	if len(common.CSVComma) > 1 {
		return nil, fmt.Errorf("--csv-comma '%s' must be exactly one character", common.CSVComma)
	}

	if len(common.CSVComment) > 1 {
		return nil, fmt.Errorf("--csv-comment '%s' must be exactly zero or one characters", common.CSVComment)
	}

	ds := &rqio.DataSpec{
		Options: make(map[string]string),
	}
	ds.Format = common.InputFormat

	// An unset comment character is passed along as the zero rune.
	var comment rune
	if len(common.CSVComment) == 1 {
		comment = []rune(common.CSVComment)[0]
	}

	if len(common.CSVComma) == 1 {
		comma := []rune(common.CSVComma)[0]
		ds.Options["csv.comma"] = util.ValueToString(comma)
	}

	ds.Options["csv.comment"] = util.ValueToString(comment)
	ds.Options["csv.skip_lines"] = util.ValueToString(common.CSVSkipLines)
	ds.Options["csv.headers"] = util.ValueToString(common.CSVHeaders)
	ds.Options["csv.infer"] = util.ValueToString(!common.CSVNoInfer)

	return ds, nil
}

// patchIn stores data inside obj at the location described by path and
// returns obj. Whatever was previously stored at that location is replaced.
//
// Intermediate path elements that are missing, or that hold something other
// than a map[string]interface{}, are replaced with fresh objects on the way
// down. An empty path leaves obj untouched. obj is modified in place and is
// assumed to be non-nil.
func patchIn(path []string, obj map[string]interface{}, data interface{}) map[string]interface{} {
	if len(path) == 0 {
		return obj
	}

	last := len(path) - 1
	cursor := obj

	// Walk (and create as needed) every object leading up to the leaf.
	for _, key := range path[:last] {
		child, isObject := cursor[key].(map[string]interface{})
		if !isObject {
			// Absent or primitive: overwrite with a new object.
			child = make(map[string]interface{})
			cursor[key] = child
		}
		cursor = child
	}

	cursor[path[last]] = data
	return obj
}

// loadData reads every data spec in dataPaths and assembles the results into
// a single object suitable for use as the Rego "data" document.
//
// Each spec is placed at the location named by its "rego.path" option. When
// a spec does not set that option, the default is the basename of the spec's
// file path with its extension removed. Specs are applied in slice order, so
// when two specs target the same Rego path the later one wins.
//
// All remaining defaults (format, CSV options) come from
// defaultInputOptions(common); the first error encountered is returned as-is.
func loadData(dataPaths []string, common *Common) (map[string]interface{}, error) {
	defaults, err := defaultInputOptions(common)
	if err != nil {
		return nil, err
	}

	merged := make(map[string]interface{})

	for _, spec := range dataPaths {
		ds, err := rqio.ParseDataSpec(spec)
		if err != nil {
			return nil, err
		}

		// The default Rego path depends on the file being loaded, so it
		// has to be refreshed before the defaults are applied.
		name := filepath.Base(ds.FilePath)
		stem := strings.TrimSuffix(name, filepath.Ext(name))
		defaults.Options["rego.path"] = stem
		ds.ResolveDefaults(defaults)

		loaded, err := rqio.LoadInput(ds)
		if err != nil {
			return nil, err
		}

		// NOTE: splitting on "." means a Rego path whose elements
		// themselves contain dots may not be handled correctly; this
		// should be investigated and fixed if needed.
		segments := strings.Split(ds.Options["rego.path"], ".")
		merged = patchIn(segments, merged, loaded)
	}

	return merged, nil
}

type printHook struct{}

func (h printHook) Print(ctx print.Context, msg string) error {
	fmt.Fprintln(os.Stderr, msg)
	return nil
}

// runRegoQuery evaluates query against inputValue and returns the result as a
// plain Go value.
//
// data and bundle are both optional. When a bundle is given, data is merged
// into it (see mergeBundles) and the merged bundle backs the store and is
// registered with the evaluator. Otherwise data alone, if non-nil, backs the
// store. A regoVersion of 0 selects ast.RegoV0; for any other value no
// explicit version option is passed to the evaluator.
//
// The evaluator's ResultSet is flattened so that it marshals sensibly:
//
//   - no results (typically an undefined query) yield nil;
//   - a single result yields its expression values directly;
//   - multiple results yield a list with one entry per result.
//
// In each case, a result with exactly one expression contributes that
// expression's value unwrapped, while any other number of expressions
// contributes a list of their values. This avoids needlessly nesting single
// values as [[value]].
func runRegoQuery(inputValue interface{}, query string, data map[string]interface{}, bundle *opabundle.Bundle, regoVersion int) (interface{}, error) {
	ctx := context.Background()

	parsedInput, err := ast.InterfaceToValue(inputValue)
	if err != nil {
		return nil, err
	}

	// Only merge when there is a bundle to merge into; in the common
	// data-only case the data is used as-is, which avoids the merge cost.
	if bundle != nil {
		bundle = mergeBundles([]*opabundle.Bundle{bundle}, data, &regoVersion)
	}

	// Back the store with the bundle's data if there is one, else with the
	// loose data, else leave it empty.
	var store storage.Store
	switch {
	case bundle != nil:
		store = inmem.NewFromObject(bundle.Data)
	case data != nil:
		store = inmem.NewFromObject(data)
	default:
		store = inmem.New()
	}

	opts := []func(*rego.Rego){
		rego.Query(query),
		rego.ParsedInput(parsedInput),
		rego.Imports([]string{"future.keywords.in"}),
		rego.EnablePrintStatements(true),
		rego.PrintHook(printHook{}),
		rego.Store(store),
	}

	if bundle != nil {
		// NOTE(review): this write transaction is handed to the
		// evaluator and is never explicitly committed or aborted here.
		txn, err := store.NewTransaction(ctx, storage.WriteParams)
		if err != nil {
			return nil, err
		}

		opts = append(opts, rego.Transaction(txn), rego.ParsedBundle("", bundle))
	}

	if regoVersion == 0 {
		opts = append(opts, rego.SetRegoVersion(ast.RegoV0))
	}

	prepared, err := rego.New(opts...).PrepareForEval(ctx)
	if err != nil {
		return nil, err
	}

	results, err := prepared.Eval(ctx)
	if err != nil {
		return nil, err
	}

	// collapse reduces one result to its expression value(s): the bare
	// value for a lone expression, otherwise a (non-nil) list of values.
	collapse := func(r rego.Result) interface{} {
		if len(r.Expressions) == 1 {
			return r.Expressions[0].Value
		}

		values := make([]interface{}, 0, len(r.Expressions))
		for _, e := range r.Expressions {
			values = append(values, e.Value)
		}
		return values
	}

	switch len(results) {
	case 0:
		// Typically the expression evaluated to undefined.
		return nil, nil
	case 1:
		// The common case; do not wrap in an outer list.
		return collapse(results[0]), nil
	default:
		all := make([]interface{}, 0, len(results))
		for _, r := range results {
			all = append(all, collapse(r))
		}
		return all, nil
	}
}

// defaultOutputOptions builds the DataSpec holding the default output format
// and output options derived from the command line flags in common.
//
// outputIsTTY reports whether the output is a terminal, which is the starting
// point for deciding whether to colorize. Escape sequences in --csv-comma are
// resolved first (this mutates common); an error is returned if the comma is
// still longer than one byte, or if the indent string cannot be unescaped.
func defaultOutputOptions(common *Common, outputIsTTY bool) (*rqio.DataSpec, error) {
	resolveEscapes(common)

	if len(common.CSVComma) > 1 {
		return nil, fmt.Errorf("--csv-comma '%s' must be exactly one character", common.CSVComma)
	}

	// Color decision, from lowest to highest priority:
	//   1. color if and only if the output is a terminal;
	//   2. --force-color turns it on, but --no-color beats --force-color;
	//   3. the NO_COLOR environment variable (https://no-color.org/)
	//      turns it off;
	//   4. on Windows, color is always off unless
	//      RQ_SUPPRESS_WINDOWS_NO_COLOR is set.
	colorize := outputIsTTY
	switch {
	case common.NoColor:
		colorize = false
	case common.ForceColor:
		colorize = true
	}
	if _, set := os.LookupEnv("NO_COLOR"); set {
		colorize = false
	}
	if runtime.GOOS == "windows" {
		if _, suppressed := os.LookupEnv("RQ_SUPPRESS_WINDOWS_NO_COLOR"); !suppressed {
			colorize = false
		}
	}

	options := map[string]string{
		"output.colorize": util.ValueToString(colorize),
		"output.pretty":   util.ValueToString(!common.Ugly),
		"output.style":    util.ValueToString(common.Style),
	}

	// Supplying a template forces the template output format.
	format := common.OutputFormat
	if common.Template != "" {
		options["output.template"] = common.Template
		format = "template"
	}

	if len(common.CSVComma) == 1 {
		comma := []rune(common.CSVComma)[0]
		options["csv.comma"] = util.ValueToString(comma)
	}

	// The indent may be written with escapes such as \t, so expand them.
	indent, err := util.Unescape(common.Indent)
	if err != nil {
		return nil, err
	}
	options["output.indent"] = util.ValueToString(indent)

	ds := &rqio.DataSpec{
		Options: options,
	}
	ds.Format = format

	return ds, nil
}

// helpfulAdvice looks for a handful of common command line mistakes and, if
// one is detected, returns an error containing a tip describing what the user
// probably meant. It returns nil when nothing suspicious is found.
//
// The mistakes checked, in order, are:
//
//   - the query does not prepare as Rego but names an existing file
//     (`rq somefile` instead of `rq < somefile`);
//   - the input format is unknown but names an existing file, or looks like
//     a dataspec (`-i` used where `-I` was meant);
//   - the output format is unknown but looks like a dataspec (`-o` used
//     where `-O` was meant).
func helpfulAdvice(query string, common *Common) error {
	// Characters whose presence suggests a dataspec rather than a format.
	const dataspecChars = "=:{}"

	exists := func(path string) bool {
		_, err := os.Stat(path)
		return err == nil
	}

	// The file check is only made when the query fails to prepare.
	if _, err := rego.New(rego.Query(query)).PrepareForEval(context.Background()); err != nil && exists(query) {
		return fmt.Errorf("TIP: query '%s' is not valid Rego syntax, but is an extant filepath, did you mean 'rq < /some/file'?", query)
	}

	inFormat := common.InputFormat
	if _, err := rqio.SelectInputHandler(inFormat); err != nil {
		switch {
		case exists(inFormat):
			return fmt.Errorf("TIP: '%s' is not a valid input format, but is an extant filepath, did you mean '-I/--input'?", inFormat)
		case strings.ContainsAny(inFormat, dataspecChars):
			return fmt.Errorf("TIP: '%s' is not a valid input format, but looks like it might be a dataspec, did you mean '-I/--input'?", inFormat)
		}
	}

	outFormat := common.OutputFormat
	if _, err := rqio.SelectOutputHandler(outFormat); err != nil && strings.ContainsAny(outFormat, dataspecChars) {
		return fmt.Errorf("TIP: '%s' is not a valid output format, but looks like it might be a dataspec, did you mean '-O/--output'?", outFormat)
	}

	return nil
}

// getScalar returns the last value recorded under key in dict, together with
// true. If key is absent or has no values, it returns defVal and false.
func getScalar(dict map[string][]string, key string, defVal string) (string, bool) {
	// Indexing a missing key yields a nil slice, so one length check covers
	// both the "absent" and the "present but empty" cases.
	if values := dict[key]; len(values) > 0 {
		return values[len(values)-1], true
	}

	return defVal, false
}

// getVector returns every value recorded under key in dict, together with
// true. If key is absent, it returns defVal and false. A key that is present
// with no values is returned as-is (with true), not replaced by defVal.
func getVector(dict map[string][]string, key string, defVal []string) ([]string, bool) {
	if values, present := dict[key]; present {
		return values, true
	}

	return defVal, false
}

// getBool looks up the scalar stored under key in dict and interprets it as a
// boolean using util.Truthy, returning (value, ok). If key has no value, it
// returns defVal and false.
func getBool(dict map[string][]string, key string, defVal bool) (bool, bool) {
	raw, found := getScalar(dict, key, util.ValueToString(defVal))
	if found {
		return util.Truthy(raw), true
	}

	return defVal, false
}

// getInt looks up the scalar stored under key in dict and parses it as a
// base-10 integer, returning (value, ok).
//
// If key has no value, or the value is not a valid integer, it returns defVal
// and false, mirroring how getScalar and getBool fall back to their defaults.
func getInt(dict map[string][]string, key string, defVal int) (int, bool) {
	// The string default is irrelevant here: it is only returned alongside
	// ok == false, in which case defVal is used instead.
	s, ok := getScalar(dict, key, "")
	if !ok {
		return defVal, false
	}

	i, err := strconv.Atoi(s)
	if err != nil {
		return defVal, false
	}

	return i, true
}

// parseScriptConfig extracts the configuration directives embedded in an rq
// script and returns them as a map from parameter name to the list of values
// given for it, in order of appearance.
//
// Directives look like:
//
//	# rq: query input.stuff
//	# rq: input.datasource csv:csv.infer=true:./data.csv
//
// For each line, everything after the first `#` is considered. After trimming
// non-newline whitespace, that text must consist of the marker `rq`, a colon,
// and the directive body; lines that do not match are silently skipped. The
// first space-delimited field of the body is the parameter, and the remainder
// of the line (trimmed) is the value.
//
// For example:
//
//	#     rq: foo.bar.baz hello rq    world!
//
// is parsed to have the parameter `foo.bar.baz` and the value
// `hello rq    world!`.
//
// NOTE(review): whatever precedes the first `#` on a line is not inspected,
// so a trailing `# rq: ...` comment after code is also treated as a
// directive. Confirm whether that is intended.
//
// Directives can be specified more than once. For directives that expect
// scalar values, the last (furthest down in the input) is the canonical one,
// by convention.
//
// An error is returned if a line carries the `rq:` marker but the body has no
// space separating a parameter from a value. The line number in the error is
// 1-based.
func parseScriptConfig(script string) (map[string][]string, error) {
	directives := map[string][]string{}

	for i, line := range strings.Split(script, "\n") {
		_, comment, found := strings.Cut(line, "#")
		if !found {
			continue
		}

		marker, body, found := strings.Cut(strings.Trim(comment, whitespace), ":")
		if !found || strings.Trim(marker, whitespace) != "rq" {
			continue
		}

		parameter, value, found := strings.Cut(strings.Trim(body, whitespace), " ")
		if !found {
			// i is a 0-based index; people count lines from 1.
			return nil, fmt.Errorf("rq script directives must have a parameter and a value, you only provided a parameter on line %d", i+1)
		}

		parameter = strings.Trim(parameter, whitespace)
		value = strings.Trim(value, whitespace)
		directives[parameter] = append(directives[parameter], value)
	}

	return directives, nil
}

// mergeBundles generates a new bundle containing the files and data from all
// of the provided bundles. Unlike OPA's bundle.Merge(), this function silently
// overwrites conflicting paths in the order of the bundle slice provided. That
// is, in the case of a conflict, the conflicting module file whose bundle has
// the greatest index in the bundles slice takes precedence.
//
// The data argument, if non-nil, is applied last, meaning it takes precedence
// over the data of all bundles provided.
//
// This function considers the Data, Modules, WasmModules, and PlanModules
// fields of the bundles, but ignores all other fields. Nil entries in the
// bundles slice are skipped. The order of the module slices in the merged
// bundle is unspecified.
//
// When a merge is performed, the provided regoVersion value is used directly
// in the merged bundle's manifest, irrespective of the rego version specified
// by any of the source bundles.
//
// Two cases do not perform a merge: with no bundles and nil data the result
// is nil, and with exactly one bundle and nil data that bundle is returned
// as-is (the same pointer, with its manifest untouched).
func mergeBundles(bundles []*opabundle.Bundle, data map[string]interface{}, regoVersion *int) *opabundle.Bundle {
	if len(bundles) == 0 && data == nil {
		return nil
	}

	if len(bundles) == 1 && data == nil {
		return bundles[0]
	}

	mergedData := make(map[string]interface{})
	modules := make(map[string]opabundle.ModuleFile)
	wasmModules := make(map[string]opabundle.WasmModuleFile)
	planModules := make(map[string]opabundle.PlanModuleFile)

	// overlay copies every path enumerated by util.Keyspace for src into
	// mergedData, replacing whatever an earlier source stored there.
	overlay := func(src map[string]interface{}) {
		for _, p := range util.Keyspace(util.MapPath{}, src) {
			if v, ok := p.Access(src); ok {
				mergedData = patchIn(p.Slice(), mergedData, v)
			}
		}
	}

	for _, b := range bundles {
		if b == nil {
			// Nothing to contribute; dereferencing would panic.
			continue
		}

		// Keying by path makes later bundles replace earlier ones.
		for _, m := range b.Modules {
			modules[m.Path] = m
		}

		for _, w := range b.WasmModules {
			wasmModules[w.Path] = w
		}

		for _, p := range b.PlanModules {
			planModules[p.Path] = p
		}

		if b.Data != nil {
			overlay(b.Data)
		}
	}

	if data != nil {
		overlay(data)
	}

	merged := &opabundle.Bundle{
		Manifest: opabundle.Manifest{
			Roots:       &[]string{},
			RegoVersion: regoVersion,
		},
		Data:        mergedData,
		Modules:     make([]opabundle.ModuleFile, 0, len(modules)),
		WasmModules: make([]opabundle.WasmModuleFile, 0, len(wasmModules)),
		PlanModules: make([]opabundle.PlanModuleFile, 0, len(planModules)),
	}

	for _, m := range modules {
		merged.Modules = append(merged.Modules, m)
	}

	for _, w := range wasmModules {
		merged.WasmModules = append(merged.WasmModules, w)
	}

	for _, p := range planModules {
		merged.PlanModules = append(merged.PlanModules, p)
	}

	return merged
}

// saveBundle merges data into bundle (see mergeBundles) and writes the result
// to a newly created file at path.
//
// Either data or bundle may be nil, but not both: if there is nothing to
// write, an error is returned and no file is created. The written bundle's
// manifest is replaced with one carrying only regoVersion. The bundle passed
// in by the caller is not modified.
//
// Errors from creating, writing, or closing the file are returned.
func saveBundle(data map[string]interface{}, bundle *opabundle.Bundle, path string, regoVersion *int) (err error) {
	// Leave a nil bundle out of the merge entirely rather than handing the
	// merge a nil entry it might dereference.
	var sources []*opabundle.Bundle
	if bundle != nil {
		sources = append(sources, bundle)
	}

	merged := mergeBundles(sources, data, regoVersion)
	if merged == nil {
		return fmt.Errorf("failed to merge bundles")
	}

	// mergeBundles can return the caller's own bundle, so adjust the
	// manifest on a shallow copy instead of through the pointer.
	out := *merged
	out.Manifest = opabundle.Manifest{RegoVersion: regoVersion}

	f, err := os.Create(path)
	if err != nil {
		return err
	}
	defer func() {
		// A failed close can mean the data never reached the disk, so
		// report it unless an earlier error is already being returned.
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	return opabundle.NewWriter(f).Write(out)
}

// testCheck implements the check for -T/-test, to determine if the result of
// the query should cause a nonzero exit. A return value of true means that
// v was considered falsey.
//
// A value is falsey when it is nil, or is the zero value of its type: 0 for
// numbers, "" for strings, false for booleans, and a length of zero for the
// supported map and slice types. A json.Number is falsey when it is
// numerically zero, or when it cannot be parsed as a number at all.
//
// testCheck panics if v has a type that is not handled here.
func testCheck(v interface{}) bool {
	// We really do have to know the type of an interface to see _which_
	// zero value we need to compare against. For example:
	// https://go.dev/play/p/bAXgeOOaw9w
	//
	// The list of all possible types OPA can return to the Go runtime
	// can be found here: https://github.com/open-policy-agent/opa/blob/89bb42b6a7724d1f6122f3eb3f4e6cf41e9bd10f/ast/term.go#L173
	//
	// This switch statement is a little excessive.
	switch x := v.(type) {
	case nil:
		return true
	case int:
		return x == 0
	case int8:
		return x == 0
	case int16:
		return x == 0
	case int32:
		// note that this also covers rune
		return x == 0
	case int64:
		return x == 0
	case uint:
		return x == 0
	case uint8:
		return x == 0
	case uint16:
		return x == 0
	case uint32:
		return x == 0
	case uint64:
		return x == 0
	case float32:
		return x == 0
	case float64:
		return x == 0
	case string:
		return x == ""
	case bool:
		return !x
	case json.Number:
		// Try the exact integer interpretation first, then fall back to
		// a float so that values such as "1.5" are not mistaken for
		// zero just because they are not integers.
		if n, err := x.Int64(); err == nil {
			return n == 0
		}
		f, err := x.Float64()
		return (err != nil) || (f == 0)
	case map[string]any:
		return len(x) == 0
	case map[any]any:
		return len(x) == 0
	case map[string]string:
		return len(x) == 0
	case []string:
		return len(x) == 0
	case []any:
		return len(x) == 0
	default:
		panic(fmt.Sprintf("don't know how to check value of type '%T' returned from OPA: %+v\n", v, v))
	}
}
