Skip to content

Commit

Permalink
Code formatting and cleanup (#442)
Browse files Browse the repository at this point in the history
* Formatting, extract error messages
* More refactor
* Move errors to a separate file
* Add missing file
* Misc
  • Loading branch information
sosedoff authored Sep 29, 2019
1 parent c4db197 commit 7a64500
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 65 deletions.
117 changes: 69 additions & 48 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"encoding/base64"
"errors"
"fmt"
neturl "net/url"
"regexp"
Expand Down Expand Up @@ -31,11 +30,11 @@ var (
func DB(c *gin.Context) *client.Client {
if command.Opts.Sessions {
return DbSessions[getSessionId(c.Request)]
} else {
return DbClient
}
return DbClient
}

// setClient sets the database client connection for the sessions
func setClient(c *gin.Context, newClient *client.Client) error {
currentClient := DB(c)
if currentClient != nil {
Expand All @@ -47,23 +46,26 @@ func setClient(c *gin.Context, newClient *client.Client) error {
return nil
}

sessionId := getSessionId(c.Request)
if sessionId == "" {
return errors.New("Session ID is required")
sid := getSessionId(c.Request)
if sid == "" {
return errSessionRequired
}

DbSessions[sessionId] = newClient
DbSessions[sid] = newClient
return nil
}

// GetHome renderes the home page
func GetHome(c *gin.Context) {
serveStaticAsset("/index.html", c)
}

// GetAsset renders the requested static asset
func GetAsset(c *gin.Context) {
serveStaticAsset(c.Params.ByName("path"), c)
}

// GetSessions renders the number of active sessions
func GetSessions(c *gin.Context) {
// In debug mode endpoint will return a lot of sensitive information
// like full database connection string and all query history.
Expand All @@ -74,6 +76,7 @@ func GetSessions(c *gin.Context) {
successResponse(c, gin.H{"sessions": len(DbSessions)})
}

// ConnectWithBackend creates a new connection based on backend resource
func ConnectWithBackend(c *gin.Context) {
// Setup a new backend client
backend := Backend{
Expand All @@ -90,12 +93,12 @@ func ConnectWithBackend(c *gin.Context) {
}

// Make the new session
sessionId, err := securerandom.Uuid()
sid, err := securerandom.Uuid()
if err != nil {
badRequest(c, err)
return
}
c.Request.Header.Add("x-session-id", sessionId)
c.Request.Header.Add("x-session-id", sid)

// Connect to the database
cl, err := client.NewFromUrl(cred.DatabaseURL, nil)
Expand All @@ -116,20 +119,22 @@ func ConnectWithBackend(c *gin.Context) {
return
}

c.Redirect(301, fmt.Sprintf("/%s?session=%s", command.Opts.Prefix, sessionId))
redirectURI := fmt.Sprintf("/%s?session=%s", command.Opts.Prefix, sid)
c.Redirect(301, redirectURI)
}

// Connect creates a new client connection
func Connect(c *gin.Context) {
if command.Opts.LockSession {
badRequest(c, "Session is locked")
badRequest(c, errSessionLocked)
return
}

var sshInfo *shared.SSHInfo
url := c.Request.FormValue("url")

if url == "" {
badRequest(c, "Url parameter is required")
badRequest(c, errURLRequired)
return
}

Expand Down Expand Up @@ -170,9 +175,10 @@ func Connect(c *gin.Context) {
successResponse(c, info.Format()[0])
}

// SwitchDb perform database switch for the client connection
func SwitchDb(c *gin.Context) {
if command.Opts.LockSession {
badRequest(c, "Session is locked")
badRequest(c, errSessionLocked)
return
}

Expand All @@ -181,31 +187,30 @@ func SwitchDb(c *gin.Context) {
name = c.Request.FormValue("db")
}
if name == "" {
badRequest(c, "Database name is not provided")
badRequest(c, errDatabaseNameRequired)
return
}

conn := DB(c)
if conn == nil {
badRequest(c, "Not connected")
badRequest(c, errNotConnected)
return
}

// Do not allow switching databases for connections from third-party backends
if conn.External {
badRequest(c, "Session is locked")
badRequest(c, errSessionLocked)
return
}

currentUrl, err := neturl.Parse(conn.ConnectionString)
currentURL, err := neturl.Parse(conn.ConnectionString)
if err != nil {
badRequest(c, "Unable to parse current connection string")
badRequest(c, errInvalidConnString)
return
}
currentURL.Path = name

currentUrl.Path = name

cl, err := client.NewFromUrl(currentUrl.String(), nil)
cl, err := client.NewFromUrl(currentURL.String(), nil)
if err != nil {
badRequest(c, err)
return
Expand All @@ -232,16 +237,16 @@ func SwitchDb(c *gin.Context) {
successResponse(c, info.Format()[0])
}

// Disconnect closes the current database connection
func Disconnect(c *gin.Context) {
if command.Opts.LockSession {
badRequest(c, "Session is locked")
badRequest(c, errSessionLocked)
return
}

conn := DB(c)

if conn == nil {
badRequest(c, "Not connected")
badRequest(c, errNotConnected)
return
}

Expand All @@ -254,53 +259,59 @@ func Disconnect(c *gin.Context) {
successResponse(c, gin.H{"success": true})
}

func GetDatabases(c *gin.Context) {
conn := DB(c)
if conn.External {
errorResponse(c, 403, "Not permitted")
return
}

names, err := DB(c).Databases()
serveResult(c, names, err)
}

func GetObjects(c *gin.Context) {
result, err := DB(c).Objects()
if err != nil {
badRequest(c, err)
return
}
successResponse(c, client.ObjectsFromResult(result))
}

// RunQuery executes the query
func RunQuery(c *gin.Context) {
query := cleanQuery(c.Request.FormValue("query"))

if query == "" {
badRequest(c, "Query parameter is missing")
badRequest(c, errQueryRequired)
return
}

HandleQuery(query, c)
}

// ExplainQuery renders query analyze profile
func ExplainQuery(c *gin.Context) {
query := cleanQuery(c.Request.FormValue("query"))

if query == "" {
badRequest(c, "Query parameter is missing")
badRequest(c, errQueryRequired)
return
}

HandleQuery(fmt.Sprintf("EXPLAIN ANALYZE %s", query), c)
}

// GetDatabases renders a list of all databases on the server
func GetDatabases(c *gin.Context) {
conn := DB(c)
if conn.External {
errorResponse(c, 403, errNotPermitted)
return
}

names, err := DB(c).Databases()
serveResult(c, names, err)
}

// GetObjects renders a list of database objects
func GetObjects(c *gin.Context) {
result, err := DB(c).Objects()
if err != nil {
badRequest(c, err)
return
}
successResponse(c, client.ObjectsFromResult(result))
}

// GetSchemas renders list of available schemas
func GetSchemas(c *gin.Context) {
res, err := DB(c).Schemas()
serveResult(c, res, err)
}

// GetTable renders table information
func GetTable(c *gin.Context) {
var res *client.Result
var err error
Expand All @@ -314,6 +325,7 @@ func GetTable(c *gin.Context) {
serveResult(c, res, err)
}

// GetTableRows renders table rows
func GetTableRows(c *gin.Context) {
offset, err := parseIntFormValue(c, "offset", 0)
if err != nil {
Expand Down Expand Up @@ -366,6 +378,7 @@ func GetTableRows(c *gin.Context) {
serveResult(c, res, err)
}

// GetTableInfo renders a selected table information
func GetTableInfo(c *gin.Context) {
res, err := DB(c).TableInfo(c.Params.ByName("table"))
if err == nil {
Expand All @@ -375,10 +388,12 @@ func GetTableInfo(c *gin.Context) {
}
}

// GetHistory renders a list of recent queries
func GetHistory(c *gin.Context) {
successResponse(c, DB(c).History)
}

// GetConnectionInfo renders information about current connection
func GetConnectionInfo(c *gin.Context) {
res, err := DB(c).Info()

Expand All @@ -393,21 +408,25 @@ func GetConnectionInfo(c *gin.Context) {
successResponse(c, info)
}

// GetActivity renders a list of running queries
func GetActivity(c *gin.Context) {
res, err := DB(c).Activity()
serveResult(c, res, err)
}

// GetTableIndexes renders a list of database table indexes
func GetTableIndexes(c *gin.Context) {
res, err := DB(c).TableIndexes(c.Params.ByName("table"))
serveResult(c, res, err)
}

// GetTableConstraints renders a list of database constraints
func GetTableConstraints(c *gin.Context) {
res, err := DB(c).TableConstraints(c.Params.ByName("table"))
serveResult(c, res, err)
}

// HandleQuery runs the database query
func HandleQuery(query string, c *gin.Context) {
rawQuery, err := base64.StdEncoding.DecodeString(desanitize64(query))
if err == nil {
Expand Down Expand Up @@ -443,11 +462,13 @@ func HandleQuery(query string, c *gin.Context) {
}
}

// GetBookmarks renders the list of available bookmarks
func GetBookmarks(c *gin.Context) {
bookmarks, err := bookmarks.ReadAll(bookmarks.Path(command.Opts.BookmarksDir))
serveResult(c, bookmarks, err)
}

// GetInfo renders the pgweb system information
func GetInfo(c *gin.Context) {
successResponse(c, gin.H{
"version": command.Version,
Expand All @@ -456,7 +477,7 @@ func GetInfo(c *gin.Context) {
})
}

// Export database or table data
// DataExport performs database table export
func DataExport(c *gin.Context) {
db := DB(c)

Expand All @@ -473,7 +494,7 @@ func DataExport(c *gin.Context) {
// If pg_dump is not available the following code will not show an error in browser.
// This is due to the headers being written first.
if !dump.CanExport() {
badRequest(c, "pg_dump is not found")
badRequest(c, errPgDumpNotFound)
return
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/api/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (be Backend) FetchCredential(resource string, c *gin.Context) (*BackendCred
log.Println("Unable to fetch backend credential:", err)

// We dont want to expose the url of the backend here, so reply with generic error
return nil, fmt.Errorf("Unable to connect to the auth backend")
return nil, errBackendConnectError
}
defer resp.Body.Close()

Expand All @@ -67,7 +67,7 @@ func (be Backend) FetchCredential(resource string, c *gin.Context) (*BackendCred
return nil, err
}
if cred.DatabaseURL == "" {
return nil, fmt.Errorf("Database URL was not provided")
return nil, errConnStringRequired
}

return cred, nil
Expand Down
19 changes: 19 additions & 0 deletions pkg/api/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package api

import (
"errors"
)

var (
errNotConnected = errors.New("Not connected")
errNotPermitted = errors.New("Not permitted")
errConnStringRequired = errors.New("Connection string is required")
errInvalidConnString = errors.New("Invalid connection string")
errSessionRequired = errors.New("Session ID is required")
errSessionLocked = errors.New("Session is locked")
errURLRequired = errors.New("URL parameter is required")
errQueryRequired = errors.New("Query parameter is required")
errDatabaseNameRequired = errors.New("Database name is required")
errPgDumpNotFound = errors.New("pg_dump utility is not found")
errBackendConnectError = errors.New("Unable to connect to the auth backend")
)
Loading

0 comments on commit 7a64500

Please sign in to comment.