Skip to content

Commit

Permalink
Merge pull request #69 from piny940/e2e-test
Browse files Browse the repository at this point in the history
E2e testを作成
  • Loading branch information
piny940 authored Nov 7, 2024
2 parents 259e5b8 + 1dc3b8e commit 08f38a8
Show file tree
Hide file tree
Showing 9 changed files with 342 additions and 34 deletions.
21 changes: 18 additions & 3 deletions internal/api/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,18 @@ func (s *Server) OAuthInterfaceAuthorize(ctx context.Context, request OAuthInter
}
user, err := CurrentUser(ctx)
if errors.Is(err, ErrUnauthorized) {
url := s.Conf.LoginUrl + "?" + toQueryString(map[string]string{"next": this})
query := map[string]string{
"next": this,
"client_id": request.Params.ClientId,
"scope": request.Params.Scope,
"response_type": request.Params.ResponseType,
"error": string(OAuthAuthorizeErrUnauthorizedClient),
"error_description": "unauthorized_client",
}
if request.Params.State != nil {
query["state"] = *request.Params.State
}
url := s.Conf.LoginUrl + "?" + toQueryString(query)
return OAuthInterfaceAuthorize302Response{
Headers: OAuthInterfaceAuthorize302ResponseHeaders{
Location: url,
Expand All @@ -59,13 +70,17 @@ func (s *Server) OAuthInterfaceAuthorize(ctx context.Context, request OAuthInter
}, nil
}
if errors.Is(err, usecase.ErrNotApproved) {
url := s.Conf.ApproveUrl + "?" + toQueryString(map[string]string{
query := map[string]string{
"next": this,
"client_id": request.Params.ClientId,
"scope": request.Params.Scope,
"error": string(OAuthAuthorizeErrAccessDenied),
"error_description": "access_denied",
})
}
if request.Params.State != nil {
query["state"] = *request.Params.State
}
url := s.Conf.ApproveUrl + "?" + toQueryString(query)
return OAuthInterfaceAuthorize302Response{
Headers: OAuthInterfaceAuthorize302ResponseHeaders{
Location: url,
Expand Down
6 changes: 3 additions & 3 deletions internal/infrastructure/gateway/approval.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (a *ApprovalRepo) Approve(clientID oauth.ClientID, userID domain.UserID, sc
}
existScopes := make([]oauth.TypeScope, 0, len(existMScopes))
for _, s := range existMScopes {
existScopes = append(existScopes, scopeMap[s.ScopeID])
existScopes = append(existScopes, ScopeMap[s.ScopeID])
}
compactScopes := slices.Compact(scopes)
adds := make([]oauth.TypeScope, 0)
Expand All @@ -92,7 +92,7 @@ func (a *ApprovalRepo) Approve(clientID oauth.ClientID, userID domain.UserID, sc
}
mAdds := make([]*model.ApprovalScope, 0, len(adds))
for _, s := range adds {
scopeID, ok := scopeMapReverse[s]
scopeID, ok := ScopeMapReverse[s]
if !ok {
return oauth.ErrInvalidScope
}
Expand All @@ -110,7 +110,7 @@ func (a *ApprovalRepo) Approve(clientID oauth.ClientID, userID domain.UserID, sc
func toDomainApproval(approval *model.Approval, approvalScopes []*model.ApprovalScope) *oauth.Approval {
scopes := make([]oauth.TypeScope, 0, len(approvalScopes))
for _, s := range approvalScopes {
scopes = append(scopes, scopeMap[s.ScopeID])
scopes = append(scopes, ScopeMap[s.ScopeID])
}
return &oauth.Approval{
ID: oauth.ApprovalID(approval.ID),
Expand Down
4 changes: 2 additions & 2 deletions internal/infrastructure/gateway/approval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestApprovalApprove(t *testing.T) {
{ID: approvalID, ClientID: "client2", UserID: userID},
}
initialApprovalScopes := []*model.ApprovalScope{
{ApprovalID: approvalID, ScopeID: scopeMapReverse[oauth.ScopeOpenID]},
{ApprovalID: approvalID, ScopeID: ScopeMapReverse[oauth.ScopeOpenID]},
}
suites := []struct {
name string
Expand Down Expand Up @@ -76,7 +76,7 @@ func TestApprovalApprove(t *testing.T) {
t.Errorf("expected: %v, got: %v", s.expectedScopes, scopes)
}
for _, scope := range scopes {
if !slices.Contains(s.expectedScopes, scopeMap[scope.ScopeID]) {
if !slices.Contains(s.expectedScopes, ScopeMap[scope.ScopeID]) {
t.Errorf("expected: %v, got: %v", s.expectedScopes, scopes)
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/infrastructure/gateway/auth_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (a *AuthCodeRepo) Create(value string, clientID oauth.ClientID, userID doma
adds := make([]*model.AuthCodeScope, 0, len(compactScopes))
for _, s := range compactScopes {
adds = append(adds, &model.AuthCodeScope{
ScopeID: scopeMapReverse[s],
ScopeID: ScopeMapReverse[s],
AuthCodeID: code.ID,
})
}
Expand All @@ -83,7 +83,7 @@ func (a *AuthCodeRepo) Create(value string, clientID oauth.ClientID, userID doma
func toDomainAuthCode(m *model.AuthCode, mScopes []*model.AuthCodeScope) *oauth.AuthCode {
scopes := make([]oauth.TypeScope, 0, len(mScopes))
for _, s := range mScopes {
scopes = append(scopes, scopeMap[s.ScopeID])
scopes = append(scopes, ScopeMap[s.ScopeID])
}
return &oauth.AuthCode{
Value: m.Value,
Expand Down
2 changes: 1 addition & 1 deletion internal/infrastructure/gateway/auth_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestAuthCodeCreate(t *testing.T) {
if len(actualScopes) != 1 { // 重複は排除される
t.Fatalf("unexpected len(actualScopes): %d", len(actualScopes))
}
if actualScopes[0].ScopeID != scopeMapReverse[oauth.ScopeOpenID] {
if actualScopes[0].ScopeID != ScopeMapReverse[oauth.ScopeOpenID] {
t.Errorf("unexpected actualScopes[0].ScopeID: %d", actualScopes[0].ScopeID)
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/infrastructure/gateway/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package gateway

import "auth/internal/domain/oauth"

var scopeMap = map[int32]oauth.TypeScope{
var ScopeMap = map[int32]oauth.TypeScope{
0: oauth.ScopeOpenID,
1: oauth.ScopeEmail,
}
var scopeMapReverse = map[oauth.TypeScope]int32{
var ScopeMapReverse = map[oauth.TypeScope]int32{
oauth.ScopeOpenID: 0,
oauth.ScopeEmail: 1,
}
168 changes: 168 additions & 0 deletions test/e2e/oauth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package e2e

import (
"auth/internal/domain/oauth"
"auth/internal/infrastructure/gateway"
"auth/internal/infrastructure/model"
"fmt"
"io"
"net/url"
"os"
"testing"

"golang.org/x/crypto/bcrypt"
)

func TestAuthorizeCodeNotAuthenticated(t *testing.T) {
const userID = 43234
const username = "user1"
const password = "password"
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("failed to hash password: %v", err)
}
const clientOwnerID = 32478
const client1ID = "client1"
const client2ID = "client2"
const client3ID = "client3"
initialUsers := []*model.User{
{ID: userID, Name: username, EncryptedPassword: string(hashed)},
{ID: clientOwnerID, Name: "client owner", EncryptedPassword: string(hashed)},
}
initialClients := []*model.Client{
{ID: client1ID, Name: "approved", EncryptedSecret: "secret", UserID: clientOwnerID},
{ID: client2ID, Name: "partially approved", EncryptedSecret: "secret", UserID: clientOwnerID},
{ID: client3ID, Name: "not approved", EncryptedSecret: "secret", UserID: clientOwnerID},
}
const validRedirectURI = "https://example.com/callback"
initialRedirectURIs := []*model.RedirectURI{
{ClientID: client1ID, URI: validRedirectURI},
{ClientID: client2ID, URI: validRedirectURI},
{ClientID: client3ID, URI: validRedirectURI},
}
validScope := "openid email"

const approval1ID = 32284
const approval2ID = 32285
initialApprovals := []*model.Approval{
{ID: approval1ID, UserID: userID, ClientID: client1ID},
{ID: approval2ID, UserID: userID, ClientID: client2ID},
}
initialApprovalScopes := []*model.ApprovalScope{
{ApprovalID: approval1ID, ScopeID: gateway.ScopeMapReverse[oauth.ScopeOpenID]},
{ApprovalID: approval1ID, ScopeID: gateway.ScopeMapReverse[oauth.ScopeEmail]},
{ApprovalID: approval2ID, ScopeID: gateway.ScopeMapReverse[oauth.ScopeOpenID]},
}
const state = "rfejafewiofjwefiojwoefwjofprwjfrawo"

serverUrl := os.Getenv("SERVER_URL")
apiLoginUrl := os.Getenv("API_LOGIN_URL")
apiApproveUrl := os.Getenv("API_APPROVE_URL")

suites := []struct {
name string
authenticated bool
clientID string
redirectURI string
scope string
state string
expectedStatus int
authCodeIssued bool
redirectedTo *string
}{
{"not authenticated", false, client1ID, validRedirectURI, validScope, state, 302, false, ptr(apiLoginUrl)},
{"client not found", true, "invalid", validRedirectURI, validScope, state, 400, false, nil},
{"invalid redirect uri", true, client1ID, "https://example.com/invalid", validScope, state, 400, false, nil},
{"invalid scope", true, client1ID, validRedirectURI, "invalid", state, 400, false, nil},
{"partially invalid scope", true, client1ID, validRedirectURI, "openid invalid", state, 400, false, nil},
{"empty state", true, client1ID, validRedirectURI, validScope, "", 302, true, ptr(validRedirectURI)},
{"client not approved", true, client3ID, validRedirectURI, validScope, state, 302, false, ptr(apiApproveUrl)},
{"a scope not approved", true, client2ID, validRedirectURI, validScope, state, 302, false, ptr(apiApproveUrl)},
{"client approved", true, client1ID, validRedirectURI, validScope, state, 302, true, ptr(validRedirectURI)},
{"more approved", true, client1ID, validRedirectURI, string(oauth.ScopeOpenID), state, 302, true, ptr(validRedirectURI)},
}

for _, suit := range suites {
t.Run(suit.name, func(t *testing.T) {
s := newServer(t)
defer s.Close()
seed(t, initialUsers, initialClients, initialRedirectURIs, initialApprovals, initialApprovalScopes)

var cookie *string
if suit.authenticated {
_, c := login(t, s, username, password)
cookie = &c
}
query := map[string]string{
"response_type": "code",
"client_id": suit.clientID,
"redirect_uri": suit.redirectURI,
"scope": suit.scope,
}
if suit.state != "" {
query["state"] = suit.state
}
res := authedGet(t, s.URL+"/oauth/authorize?"+mapToQuery(t, query), cookie)
defer res.Body.Close()

if res.StatusCode != suit.expectedStatus {
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("failed to read response body: %v", err)
}
t.Fatalf("expected status code: %d, but got %d. response body: %v", suit.expectedStatus, res.StatusCode, string(body))
}
if suit.expectedStatus == 302 {
actual, err := url.Parse(res.Header.Get("Location"))
if err != nil {
t.Fatalf("failed to parse url: %v", err)
}
{ // test the redirect target is correct
actualUrl := fmt.Sprintf("%s://%s%s", actual.Scheme, actual.Host, actual.Path)
if actualUrl != *suit.redirectedTo {
t.Errorf("unexpected url: %v, expect %v", actualUrl, *suit.redirectedTo)
}
}
actualQuery := actual.Query()
if suit.state != "" {
if actualQuery.Get("state") != suit.state {
t.Errorf("unexpected state: %v", actualQuery.Get("state"))
}
}
if suit.authCodeIssued {
if actualQuery.Get("code") == "" {
t.Errorf("code is not issued")
}
} else {
next, err := url.Parse(actualQuery.Get("next"))
if err != nil {
t.Fatalf("failed to parse next url: %v", err)
}
actualUrl := fmt.Sprintf("%s://%s%s", next.Scheme, next.Host, next.Path)
expectedUrl := fmt.Sprintf("%s%s", serverUrl, "/oauth/authorize")
if actualUrl != expectedUrl {
t.Errorf("unexpected url: %v, expect %v", actualUrl, expectedUrl)
}
nextQuery := next.Query()
if nextQuery.Get("response_type") != "code" {
t.Errorf("unexpected response_type: %v, expect code", nextQuery.Get("response_type"))
}
if nextQuery.Get("client_id") != suit.clientID {
t.Errorf("unexpected client_id: %v, expect %v", nextQuery.Get("client_id"), suit.clientID)
}
if nextQuery.Get("redirect_uri") != suit.redirectURI {
t.Errorf("unexpected redirect_uri: %v, expect %v", nextQuery.Get("redirect_uri"), suit.redirectURI)
}
if nextQuery.Get("scope") != suit.scope {
t.Errorf("unexpected scope: %v, expect %v", nextQuery.Get("scope"), suit.scope)
}
if suit.state != "" {
if nextQuery.Get("state") != suit.state {
t.Errorf("unexpected state: %v, expect %v", nextQuery.Get("state"), suit.state)
}
}
}
}
})
}
}
49 changes: 48 additions & 1 deletion test/e2e/user_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,59 @@
package e2e

import (
"auth/internal/api"
"bytes"
"net/http"
"testing"
)

func TestSignupLogin(t *testing.T) {
s := newServer(t)
defer s.Close()

login(t, s)
name := randomString(t, 10)
password := randomString(t, 16)

{ // signup
signupInput := &api.UsersReqSignup{
Name: name,
Password: password,
PasswordConfirmation: password,
}
body := toJSON(t, signupInput)
resp, err := http.Post(s.URL+"/users/signup", "application/json", bytes.NewBuffer(body))
if err != nil {
t.Fatalf("failed to post: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
t.Fatalf("failed to create user: %v", resp.StatusCode)
}
}

user, cookie := login(t, s, name, password)
if user.Name != name {
t.Fatalf("unexpected name: %v", user.Name)
}

// test cookie is valid

req, err := http.NewRequest(http.MethodGet, s.URL+"/session", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("Cookie", cookie)
resp, err := (&http.Client{}).Do(req)
if err != nil {
t.Fatalf("failed to get: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("failed to get: %v", resp.StatusCode)
}
resBody := &struct{ User *api.User }{}
fromJSONBody(t, resp.Body, resBody)
if resBody.User.Name != name {
t.Fatalf("unexpected name: %v", resBody.User.Name)
}
}
Loading

0 comments on commit 08f38a8

Please sign in to comment.