diff --git a/cmd/postgrescmd/preparecmd.go b/cmd/postgrescmd/preparecmd.go index d1a8c89..9771560 100644 --- a/cmd/postgrescmd/preparecmd.go +++ b/cmd/postgrescmd/preparecmd.go @@ -36,6 +36,11 @@ This operation is only required to run once for each postgresql instance.`, Name: "namespace", Aliases: []string{"n"}, }, + &cli.StringFlag{ + Name: "schema", + Value: "public", + Usage: "Schema to grant access to", + }, }, Before: func(context *cli.Context) error { if context.Args().Len() < 1 { @@ -51,6 +56,7 @@ This operation is only required to run once for each postgresql instance.`, allPrivs := context.Bool("all-privs") namespace := context.String("namespace") cluster := context.String("context") + schema := context.String("schema") fmt.Println(context.Command.Description) @@ -61,7 +67,7 @@ This operation is only required to run once for each postgresql instance.`, return fmt.Errorf("cancelled by user") } - return postgres.PrepareAccess(context.Context, appName, namespace, cluster, allPrivs) + return postgres.PrepareAccess(context.Context, appName, namespace, cluster, schema, allPrivs) }, } } diff --git a/cmd/postgrescmd/revokecmd.go b/cmd/postgrescmd/revokecmd.go index 980ed60..b0d5ed3 100644 --- a/cmd/postgrescmd/revokecmd.go +++ b/cmd/postgrescmd/revokecmd.go @@ -32,6 +32,11 @@ This operation is only required to run once for each postgresql instance.`, Name: "namespace", Aliases: []string{"n"}, }, + &cli.StringFlag{ + Name: "schema", + Value: "public", + Usage: "Schema to revoke access from", + }, }, Before: func(context *cli.Context) error { if context.Args().Len() < 1 { @@ -46,6 +51,7 @@ This operation is only required to run once for each postgresql instance.`, namespace := context.String("namespace") cluster := context.String("context") + schema := context.String("schema") fmt.Println(context.Command.Description) @@ -56,7 +62,7 @@ This operation is only required to run once for each postgresql instance.`, return fmt.Errorf("cancelled by user") } - return postgres.RevokeAccess(context.Context, appName, namespace, cluster) + return postgres.RevokeAccess(context.Context, appName, namespace, cluster, schema) }, } } diff --git a/pkg/postgres/access.go b/pkg/postgres/access.go index cd48ac3..857d4db 100644 --- a/pkg/postgres/access.go +++ b/pkg/postgres/access.go @@ -3,39 +3,42 @@ package postgres import ( "context" "database/sql" + "strings" + + "github.com/lib/pq" ) -var grantAllPrivs = `ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO cloudsqliamuser; - ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO cloudsqliamuser; - GRANT ALL ON ALL TABLES IN SCHEMA public TO cloudsqliamuser; - GRANT ALL ON ALL SEQUENCES IN SCHEMA public TO cloudsqliamuser; - GRANT CREATE ON SCHEMA public TO cloudsqliamuser;` +var grantAllPrivs = `ALTER DEFAULT PRIVILEGES IN SCHEMA $schema GRANT ALL ON TABLES TO cloudsqliamuser; + ALTER DEFAULT PRIVILEGES IN SCHEMA $schema GRANT ALL ON SEQUENCES TO cloudsqliamuser; + GRANT ALL ON ALL TABLES IN SCHEMA $schema TO cloudsqliamuser; + GRANT ALL ON ALL SEQUENCES IN SCHEMA $schema TO cloudsqliamuser; + GRANT CREATE ON SCHEMA $schema TO cloudsqliamuser;` -var grantSelectPrivs = `ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO cloudsqliamuser; - ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO cloudsqliamuser; - GRANT SELECT ON ALL TABLES IN SCHEMA public TO cloudsqliamuser; - GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO cloudsqliamuser;` +var grantSelectPrivs = `ALTER DEFAULT PRIVILEGES IN SCHEMA $schema GRANT SELECT ON TABLES TO cloudsqliamuser; + ALTER DEFAULT PRIVILEGES IN SCHEMA $schema GRANT SELECT ON SEQUENCES TO cloudsqliamuser; + GRANT SELECT ON ALL TABLES IN SCHEMA $schema TO cloudsqliamuser; + GRANT SELECT ON ALL SEQUENCES IN SCHEMA $schema TO cloudsqliamuser;` // this is used for all privileges and select, as it covers both cases -var revokeAllPrivs = `ALTER DEFAULT PRIVILEGES IN SCHEMA public REVOKE ALL ON TABLES FROM cloudsqliamuser; - ALTER DEFAULT PRIVILEGES IN SCHEMA public REVOKE ALL ON SEQUENCES FROM cloudsqliamuser; - REVOKE ALL ON ALL TABLES IN SCHEMA public FROM cloudsqliamuser; - REVOKE ALL ON ALL SEQUENCES IN SCHEMA public FROM cloudsqliamuser; - REVOKE CREATE ON SCHEMA public FROM cloudsqliamuser;` +var revokeAllPrivs = `ALTER DEFAULT PRIVILEGES IN SCHEMA $schema REVOKE ALL ON TABLES FROM cloudsqliamuser; + ALTER DEFAULT PRIVILEGES IN SCHEMA $schema REVOKE ALL ON SEQUENCES FROM cloudsqliamuser; + REVOKE ALL ON ALL TABLES IN SCHEMA $schema FROM cloudsqliamuser; + REVOKE ALL ON ALL SEQUENCES IN SCHEMA $schema FROM cloudsqliamuser; + REVOKE CREATE ON SCHEMA $schema FROM cloudsqliamuser;` -func PrepareAccess(ctx context.Context, appName, namespace, cluster string, allPrivs bool) error { +func PrepareAccess(ctx context.Context, appName, namespace, cluster, schema string, allPrivs bool) error { if allPrivs { - return sqlExecAsAppUser(ctx, appName, namespace, cluster, grantAllPrivs) + return sqlExecAsAppUser(ctx, appName, namespace, cluster, schema, grantAllPrivs) } else { - return sqlExecAsAppUser(ctx, appName, namespace, cluster, grantSelectPrivs) + return sqlExecAsAppUser(ctx, appName, namespace, cluster, schema, grantSelectPrivs) } } -func RevokeAccess(ctx context.Context, appName, namespace, cluster string) error { - return sqlExecAsAppUser(ctx, appName, namespace, cluster, revokeAllPrivs) +func RevokeAccess(ctx context.Context, appName, namespace, cluster, schema string) error { + return sqlExecAsAppUser(ctx, appName, namespace, cluster, schema, revokeAllPrivs) } -func sqlExecAsAppUser(ctx context.Context, appName, namespace, cluster, statement string) error { +func sqlExecAsAppUser(ctx context.Context, appName, namespace, cluster, schema, statement string) error { dbInfo, err := NewDBInfo(appName, namespace, cluster) if err != nil { return err @@ -46,6 +49,8 @@ func sqlExecAsAppUser(ctx context.Context, appName, namespace, cluster, statemen return err } + schema = pq.QuoteIdentifier(schema) + statement = strings.ReplaceAll(statement, "$schema", schema) db, err := sql.Open("cloudsqlpostgres", connectionInfo.ProxyConnectionString()) if err != nil { return err