Skip to content

Commit

Permalink
feat: Allow passing an auth token when downloading Hub plugins (#138)
Browse files Browse the repository at this point in the history

Related to cloudquery/cloudquery-issues#737 (internal issue).
This should unlock paid plugins downloads from the Hub. CLI PR coming in a bit

---
  • Loading branch information
erezrokah authored Oct 26, 2023
1 parent f92d39c commit b9f491d
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 17 deletions.
11 changes: 8 additions & 3 deletions managedplugin/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func getURLLocation(ctx context.Context, org string, name string, version string
return "", fmt.Errorf("failed to find plugin %s/%s version %s", org, name, version)
}

func DownloadPluginFromHub(ctx context.Context, localPath string, team string, name string, version string, typ PluginType) error {
func DownloadPluginFromHub(ctx context.Context, authToken, localPath, team, name, version string, typ PluginType) error {
downloadDir := filepath.Dir(localPath)
if _, err := os.Stat(localPath); err == nil {
return nil
Expand All @@ -97,7 +97,12 @@ func DownloadPluginFromHub(ctx context.Context, localPath string, team string, n
return http.ErrUseLastResponse
},
}
c, err := cloudquery_api.NewClient(APIBaseURL)
c, err := cloudquery_api.NewClient(APIBaseURL, cloudquery_api.WithRequestEditorFn(func(ctx context.Context, req *http.Request) error {
if authToken != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken))
}
return nil
}))
c.Client = client
if err != nil {
return fmt.Errorf("failed to create Hub API client: %w", err)
Expand All @@ -109,7 +114,7 @@ func DownloadPluginFromHub(ctx context.Context, localPath string, team string, n
}
defer downloadURL.Body.Close()
if downloadURL.StatusCode == http.StatusNotFound {
return fmt.Errorf("failed to get plugin url for %v %v/%v@%v: plugin version not found", typ, team, name, version)
return fmt.Errorf("failed to get plugin url for %v %v/%v@%v: plugin version not found. If you're trying to use a paid plugin you'll need to run `cloudquery login` first", typ, team, name, version)
}
if downloadURL.StatusCode == http.StatusTooManyRequests {
return fmt.Errorf("failed to get plugin url for %v %v/%v@%v: too many requests. Try logging in via `cloudquery login` to increase rate limits", typ, team, name, version)
Expand Down
6 changes: 3 additions & 3 deletions managedplugin/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestDownloadPluginFromGithubIntegration(t *testing.T) {
}
}

func TestDownloadPluginFromCloudQueryIntegration(t *testing.T) {
func TestDownloadPluginFromCloudQueryHub(t *testing.T) {
tmp := t.TempDir()
cases := []struct {
testName string
Expand All @@ -44,11 +44,11 @@ func TestDownloadPluginFromCloudQueryIntegration(t *testing.T) {
wantErr bool
typ PluginType
}{
{testName: "should download test plugin from cloudquery registry", team: "cloudquery", plugin: "test", version: "v3.1.11", typ: PluginSource},
{testName: "should download test plugin from cloudquery registry", team: "cloudquery", plugin: "azuredevops", version: "v3.0.12", typ: PluginSource},
}
for _, tc := range cases {
t.Run(tc.testName, func(t *testing.T) {
err := DownloadPluginFromHub(context.Background(), path.Join(tmp, tc.testName), tc.team, tc.plugin, tc.version, tc.typ)
err := DownloadPluginFromHub(context.Background(), "", path.Join(tmp, tc.testName), tc.team, tc.plugin, tc.version, tc.typ)
if (err != nil) != tc.wantErr {
t.Errorf("TestDownloadPluginFromCloudQueryIntegration() error = %v, wantErr %v", err, tc.wantErr)
return
Expand Down
6 changes: 6 additions & 0 deletions managedplugin/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,9 @@ func WithOtelEndpointInsecure() Option {
c.otelEndpointInsecure = true
}
}

func WithAuthToken(authToken string) Option {
return func(c *Client) {
c.authToken = authToken
}
}
3 changes: 2 additions & 1 deletion managedplugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ type Client struct {
otelEndpointInsecure bool
metrics *Metrics
registry Registry
authToken string
}

// typ will be deprecated soon but now required for a transition period
Expand Down Expand Up @@ -181,7 +182,7 @@ func (c *Client) downloadPlugin(ctx context.Context, typ PluginType) error {
org, name := pathSplit[0], pathSplit[1]
c.LocalPath = filepath.Join(c.directory, "plugins", typ.String(), org, name, c.config.Version, "plugin")
c.LocalPath = WithBinarySuffix(c.LocalPath)
return DownloadPluginFromHub(ctx, c.LocalPath, org, name, c.config.Version, typ)
return DownloadPluginFromHub(ctx, c.authToken, c.LocalPath, org, name, c.config.Version, typ)
default:
return fmt.Errorf("unknown registry %s", c.config.Registry.String())
}
Expand Down
20 changes: 10 additions & 10 deletions managedplugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,38 +56,38 @@ func TestManagedPluginCloudQuery(t *testing.T) {
ctx := context.Background()
tmpDir := t.TempDir()
cfg := Config{
Name: "test",
Name: "azuredevops",
Registry: RegistryCloudQuery,
Path: "cloudquery/test",
Version: "v3.1.11",
Path: "cloudquery/azuredevops",
Version: "v3.0.12",
}
clients, err := NewClients(ctx, PluginSource, []Config{cfg}, WithDirectory(tmpDir), WithNoSentry())
if err != nil {
t.Fatal(err)
}
testClient := clients.ClientByName("test")
testClient := clients.ClientByName("azuredevops")
if testClient == nil {
t.Fatal("test client not found")
t.Fatal("azuredevops client not found")
}
if err := clients.Terminate(); err != nil {
t.Fatal(err)
}
localPath := filepath.Join(tmpDir, "plugins", PluginSource.String(), "cloudquery", "test", cfg.Version, "plugin")
localPath := filepath.Join(tmpDir, "plugins", PluginSource.String(), "cloudquery", "azuredevops", cfg.Version, "plugin")
localPath = WithBinarySuffix(localPath)
cfg = Config{
Name: "test",
Name: "azuredevops",
Registry: RegistryLocal,
Path: localPath,
Version: "v3.1.11",
Version: "v3.0.12",
}

clients, err = NewClients(ctx, PluginSource, []Config{cfg}, WithDirectory(tmpDir), WithNoSentry())
if err != nil {
t.Fatal(err)
}
testClient = clients.ClientByName("test")
testClient = clients.ClientByName("azuredevops")
if testClient == nil {
t.Fatal("test client not found")
t.Fatal("azuredevops client not found")
}
if err := clients.Terminate(); err != nil {
t.Fatal(err)
Expand Down

0 comments on commit b9f491d

Please sign in to comment.