diff --git a/cmd/kagi/main.go b/cmd/kagi/main.go index 8b3ac21..22f7b7b 100644 --- a/cmd/kagi/main.go +++ b/cmd/kagi/main.go @@ -1,11 +1,14 @@ package main import ( + "bufio" + "bytes" "errors" "flag" "fmt" "log" "os" + "strconv" "strings" "github.com/bcspragu/kagi/api" @@ -47,28 +50,68 @@ func run(args []string) error { return fmt.Errorf("error performing query: %w", err) } - response := respond(resp, query) + response, err := respond(resp, query) + if err != nil { + return fmt.Errorf("failed to build response: %w", err) + } fmt.Print(response) return nil } -func respond(resp *api.FastGPTResponse, query string) (response string) { - // remove all repeated newlines or empty lines from the output - answer := strings.ReplaceAll(resp.Data.Output, "\n\n", "\n") - - response = "# " + query + "\n" + answer + "\n" +func respond(resp *api.FastGPTResponse, query string) (string, error) { + var buf bytes.Buffer + buf.WriteString("# ") + buf.WriteString(query) + buf.WriteRune('\n') + if err := streamAndRemoveDoubleNewlines(resp.Data.Output, &buf); err != nil { + return "", fmt.Errorf("failed to remove double newlines: %w", err) + } + buf.WriteRune('\n') // If there are no references, return early if len(resp.Data.References) == 0 { - return + return buf.String(), nil } - response += "\n# References\n" + buf.WriteString("\n# References\n") for i, ref := range resp.Data.References { - response += fmt.Sprintf("%d. %s - %s - %s\n", i+1, ref.Title, ref.Link, ref.Snippet) + // fmt.Sprintf("%d. %s - %s - %s\n", i+1, ref.Title, ref.Link, ref.Snippet) + buf.WriteString(strconv.Itoa(i + 1)) + buf.WriteString(". ") + buf.WriteString(ref.Title) + buf.WriteString(" - ") + buf.WriteString(ref.Link) + buf.WriteString(" - ") + buf.WriteString(ref.Snippet) + buf.WriteRune('\n') } - return + return buf.String(), nil +} + +// Remove all repeated newlines or empty lines from the given string +func streamAndRemoveDoubleNewlines(inp string, buf *bytes.Buffer) error { + r := strings.NewReader(inp) + + sc := bufio.NewScanner(r) + + first := true + for sc.Scan() { + if sc.Text() == "" { + continue + } + if !first { + buf.WriteRune('\n') + } + first = false + buf.WriteString(sc.Text()) + } + + if err := sc.Err(); err != nil { + return fmt.Errorf("error reading input: %w", err) + } + + return nil } diff --git a/cmd/kagi/main_test.go b/cmd/kagi/main_test.go new file mode 100644 index 0000000..9bca9fb --- /dev/null +++ b/cmd/kagi/main_test.go @@ -0,0 +1,33 @@ +package main + +import ( + "bytes" + "testing" +) + +func TestStreamAndRemoveDoubleNewlines(t *testing.T) { + inp := `Hello + +I am a string with + +some double newlines + + +And a triple newline!` + + var buf bytes.Buffer + if err := streamAndRemoveDoubleNewlines(inp, &buf); err != nil { + t.Fatalf("streamAndRemoveDoubleNewlines: %v", err) + } + + want := `Hello +I am a string with +some double newlines +And a triple newline!` + + got := buf.String() + + if got != want { + t.Errorf("unexpected output = %q, want %q", got, want) + } +}