Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
mapkon authored Nov 1, 2023
2 parents 2b971f8 + 4ac5dca commit ac75498
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 12 deletions.
2 changes: 2 additions & 0 deletions cmd/saml2aws/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/versent/saml2aws/v2/cmd/saml2aws/commands"
"github.com/versent/saml2aws/v2/pkg/flags"
"github.com/versent/saml2aws/v2/pkg/prompter"
)

var (
Expand Down Expand Up @@ -46,6 +47,7 @@ func buildCmdList(s kingpin.Settings) (target *[]string) {
func main() {

log.SetOutput(os.Stderr)
prompter.SetOutputWriter(os.Stderr)
log.SetFlags(0)
logrus.SetOutput(os.Stderr)

Expand Down
27 changes: 21 additions & 6 deletions pkg/prompter/survey.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,29 @@ package prompter
import (
"errors"
"fmt"
"os"

survey "github.com/AlecAivazis/survey/v2"
survey_terminal "github.com/AlecAivazis/survey/v2/terminal"
)

// outputWriter is where for all prompts will be printed. Defaults to os.Stder.
var outputWriter survey_terminal.FileWriter = os.Stderr

// CliPrompter used to prompt for cli input
type CliPrompter struct {
}

// SetOutputWriter sets the output writer to use for all survey operations
func SetOutputWriter(writer survey_terminal.FileWriter) {
outputWriter = writer
}

// stdioOption returns the IO option to use for survey functions
func stdioOption() survey.AskOpt {
return survey.WithStdio(os.Stdin, outputWriter, os.Stderr)
}

// NewCli builds a new cli prompter
func NewCli() *CliPrompter {
return &CliPrompter{}
Expand All @@ -22,7 +37,7 @@ func (cli *CliPrompter) RequestSecurityCode(pattern string) string {
prompt := &survey.Input{
Message: fmt.Sprintf("Security Token [%s]", pattern),
}
_ = survey.AskOne(prompt, &token, survey.WithValidator(survey.Required))
_ = survey.AskOne(prompt, &token, survey.WithValidator(survey.Required), stdioOption())
return token
}

Expand All @@ -34,7 +49,7 @@ func (cli *CliPrompter) ChooseWithDefault(pr string, defaultValue string, option
Options: options,
Default: defaultValue,
}
_ = survey.AskOne(prompt, &selected, survey.WithValidator(survey.Required))
_ = survey.AskOne(prompt, &selected, survey.WithValidator(survey.Required), stdioOption())

// return the selected element index
for i, option := range options {
Expand All @@ -52,7 +67,7 @@ func (cli *CliPrompter) Choose(pr string, options []string) int {
Message: pr,
Options: options,
}
_ = survey.AskOne(prompt, &selected, survey.WithValidator(survey.Required))
_ = survey.AskOne(prompt, &selected, survey.WithValidator(survey.Required), stdioOption())

// return the selected element index
for i, option := range options {
Expand All @@ -70,7 +85,7 @@ func (cli *CliPrompter) String(pr string, defaultValue string) string {
Message: pr,
Default: defaultValue,
}
_ = survey.AskOne(prompt, &val)
_ = survey.AskOne(prompt, &val, stdioOption())
return val
}

Expand All @@ -80,7 +95,7 @@ func (cli *CliPrompter) StringRequired(pr string) string {
prompt := &survey.Input{
Message: pr,
}
_ = survey.AskOne(prompt, &val, survey.WithValidator(survey.Required))
_ = survey.AskOne(prompt, &val, survey.WithValidator(survey.Required), stdioOption())
return val
}

Expand All @@ -90,6 +105,6 @@ func (cli *CliPrompter) Password(pr string) string {
prompt := &survey.Password{
Message: pr,
}
_ = survey.AskOne(prompt, &val)
_ = survey.AskOne(prompt, &val, stdioOption())
return val
}
18 changes: 15 additions & 3 deletions pkg/provider/browser/browser.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,22 @@ import (

var logger = logrus.WithField("provider", "browser")

const DEFAULT_TIMEOUT float64 = 300000

// Client client for browser based Identity Provider
type Client struct {
Headless bool
// Setup alternative directory to download playwright browsers to
BrowserDriverDir string
Timeout int
}

// New create new browser based client
func New(idpAccount *cfg.IDPAccount) (*Client, error) {
return &Client{
Headless: idpAccount.Headless,
BrowserDriverDir: idpAccount.BrowserDriverDir,
Timeout: idpAccount.Timeout,
}, nil
}

Expand Down Expand Up @@ -75,10 +79,10 @@ func (cl *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error)
}
}()

return getSAMLResponse(page, loginDetails)
return getSAMLResponse(page, loginDetails, cl)
}

var getSAMLResponse = func(page playwright.Page, loginDetails *creds.LoginDetails) (string, error) {
var getSAMLResponse = func(page playwright.Page, loginDetails *creds.LoginDetails, client *Client) (string, error) {
logger.WithField("URL", loginDetails.URL).Info("opening browser")

if _, err := page.Goto(loginDetails.URL); err != nil {
Expand All @@ -94,7 +98,7 @@ var getSAMLResponse = func(page playwright.Page, loginDetails *creds.LoginDetail
}

logger.Info("waiting ...")
r, _ := page.WaitForRequest(signin_re)
r, _ := page.WaitForRequest(signin_re, client.waitForRequestTimeout())
data, err := r.PostData()
if err != nil {
return "", err
Expand Down Expand Up @@ -123,3 +127,11 @@ func (cl *Client) Validate(loginDetails *creds.LoginDetails) error {

return nil
}

func (cl *Client) waitForRequestTimeout() playwright.PageWaitForRequestOptions {
timeout := float64(cl.Timeout)
if timeout < 30000 {
timeout = DEFAULT_TIMEOUT
}
return playwright.PageWaitForRequestOptions{Timeout: &timeout}
}
48 changes: 45 additions & 3 deletions pkg/provider/browser/browser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestNoBrowserDriverFail(t *testing.T) {
assert.ErrorContains(t, err, "could not start driver")
}

func fakeSAMLResponse(page playwright.Page, loginDetails *creds.LoginDetails) (string, error) {
func fakeSAMLResponse(page playwright.Page, loginDetails *creds.LoginDetails, client *Client) (string, error) {
return response, nil
}

Expand Down Expand Up @@ -96,6 +96,14 @@ func TestGetSAMLResponse(t *testing.T) {
</saml:EncryptedAssertion>
</samlp:Response>
`

idpAccount := cfg.IDPAccount{
Headless: true,
Timeout: 100000,
}

client, err := New(&idpAccount)
assert.Nil(t, err)
params := url.Values{}
params.Add("foo1", "bar1")
params.Add("SAMLResponse", samlp)
Expand All @@ -107,12 +115,46 @@ func TestGetSAMLResponse(t *testing.T) {
regex, err := signinRegex()
assert.Nil(t, err)
page.Mock.On("Goto", url).Return(resp, nil)
page.Mock.On("WaitForRequest", regex).Return(req)
page.Mock.On("WaitForRequest", regex, client.waitForRequestTimeout()).Return(req)
req.Mock.On("PostData").Return(params.Encode(), nil)
// loginDetails := &creds.LoginDetails{
// URL: url,
//}
// samlResp, err := getSAMLResponse(page, loginDetails)
// samlResp, err := getSAMLResponse(page, loginDetails, client)
// assert.Nil(t, err)
// assert.Equal(t, samlp, samlResp)
}

func TestWaitForRequestOptions(t *testing.T) {
timeout := float64(100000)
idpAccount := cfg.IDPAccount{
Headless: true,
Timeout: int(timeout),
}

client, err := New(&idpAccount)
assert.Nil(t, err)

options := client.waitForRequestTimeout()
if *options.Timeout != timeout {
t.Errorf("Unexpected value for timeout [%.0f]: expected [%.0f]", *options.Timeout, timeout)
}
}

func TestWaitForRequestOptionsDefaultTimeout(t *testing.T) {
idpAccount := cfg.IDPAccount{
Headless: true,
Timeout: 1000,
}

client, err := New(&idpAccount)

if err != nil {
t.Errorf("Unable to create browser")
}

options := client.waitForRequestTimeout()
if *options.Timeout != DEFAULT_TIMEOUT {
t.Errorf("Unexpected value for timeout [%.0f]: expected [%.0f]", *options.Timeout, DEFAULT_TIMEOUT)
}
}

0 comments on commit ac75498

Please sign in to comment.