diff --git a/internal/config/push/push.go b/internal/config/push/push.go index d2dadddc3..1a2340060 100644 --- a/internal/config/push/push.go +++ b/internal/config/push/push.go @@ -5,6 +5,7 @@ import ( "fmt" "os" + "github.com/go-errors/errors" "github.com/spf13/afero" "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/flags" @@ -21,10 +22,17 @@ func Run(ctx context.Context, ref string, fsys afero.Fs) error { // Use base config when no remote is declared remote.ProjectId = ref } + cost, err := getCostMatrix(ctx, ref) + if err != nil { + return err + } fmt.Fprintln(os.Stderr, "Pushing config to project:", remote.ProjectId) console := utils.NewConsole() keep := func(name string) bool { title := fmt.Sprintf("Do you want to push %s config to remote?", name) + if item, exists := cost[name]; exists { + title = fmt.Sprintf("Enabling %s will cost you %s. Keep it enabled?", item.Name, item.Price) + } shouldPush, err := console.PromptYesNo(ctx, title, true) if err != nil { fmt.Fprintln(os.Stderr, err) @@ -33,3 +41,27 @@ func Run(ctx context.Context, ref string, fsys afero.Fs) error { } return client.UpdateRemoteConfig(ctx, remote, keep) } + +type CostItem struct { + Name string + Price string +} + +func getCostMatrix(ctx context.Context, projectRef string) (map[string]CostItem, error) { + resp, err := utils.GetSupabase().V1ListProjectAddonsWithResponse(ctx, projectRef) + if err != nil { + return nil, errors.Errorf("failed to list addons: %w", err) + } else if resp.JSON200 == nil { + return nil, errors.Errorf("unexpected list addons status %d: %s", resp.StatusCode(), string(resp.Body)) + } + costMatrix := make(map[string]CostItem, len(resp.JSON200.AvailableAddons)) + for _, addon := range resp.JSON200.AvailableAddons { + if len(addon.Variants) == 1 { + costMatrix[string(addon.Type)] = CostItem{ + Name: addon.Variants[0].Name, + Price: addon.Variants[0].Price.Description, + } + } + } + return costMatrix, nil +} diff --git a/internal/config/push/push_test.go b/internal/config/push/push_test.go new file mode 100644 index 000000000..5af8fb273 --- /dev/null +++ b/internal/config/push/push_test.go @@ -0,0 +1,114 @@ +package push + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/h2non/gock" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/supabase/cli/internal/testing/apitest" + "github.com/supabase/cli/internal/utils" +) + +func TestPushConfig(t *testing.T) { + project := apitest.RandomProjectRef() + // Setup valid access token + token := apitest.RandomAccessToken(t) + t.Setenv("SUPABASE_ACCESS_TOKEN", string(token)) + + t.Run("throws error on malformed config", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + require.NoError(t, utils.WriteFile(utils.ConfigPath, []byte("malformed"), fsys)) + // Run test + err := Run(context.Background(), "", fsys) + // Check error + assert.ErrorContains(t, err, "toml: expected = after a key, but the document ends there") + }) + + t.Run("throws error on service unavailable", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Setup mock api + defer gock.OffAll() + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/billing/addons"). + Reply(http.StatusServiceUnavailable) + // Run test + err := Run(context.Background(), project, fsys) + // Check error + assert.ErrorContains(t, err, "unexpected list addons status 503:") + }) +} + +func TestCostMatrix(t *testing.T) { + project := apitest.RandomProjectRef() + // Setup valid access token + token := apitest.RandomAccessToken(t) + t.Setenv("SUPABASE_ACCESS_TOKEN", string(token)) + + t.Run("fetches cost matrix", func(t *testing.T) { + // Setup mock api + defer gock.OffAll() + gock.New(utils.DefaultApiHost). + Get("/v1/projects/"+project+"/billing/addons"). + Reply(http.StatusOK). + SetHeader("Content-Type", "application/json"). + BodyString(`{ + "available_addons":[{ + "name": "Advanced MFA - Phone", + "type": "auth_mfa_phone", + "variants": [{ + "id": "auth_mfa_phone_default", + "name": "Advanced MFA - Phone", + "price": { + "amount": 0.1027, + "description": "$75/month, then $10/month", + "interval": "hourly", + "type": "usage" + } + }] + }, { + "name": "Advanced MFA - WebAuthn", + "type": "auth_mfa_web_authn", + "variants": [{ + "id": "auth_mfa_web_authn_default", + "name": "Advanced MFA - WebAuthn", + "price": { + "amount": 0.1027, + "description": "$75/month, then $10/month", + "interval": "hourly", + "type": "usage" + } + }] + }] + }`) + // Run test + cost, err := getCostMatrix(context.Background(), project) + // Check error + assert.NoError(t, err) + require.Len(t, cost, 2) + assert.Equal(t, "Advanced MFA - Phone", cost["auth_mfa_phone"].Name) + assert.Equal(t, "$75/month, then $10/month", cost["auth_mfa_phone"].Price) + assert.Equal(t, "Advanced MFA - WebAuthn", cost["auth_mfa_web_authn"].Name) + assert.Equal(t, "$75/month, then $10/month", cost["auth_mfa_web_authn"].Price) + }) + + t.Run("throws error on network error", func(t *testing.T) { + errNetwork := errors.New("network error") + // Setup mock api + defer gock.OffAll() + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/billing/addons"). + ReplyError(errNetwork) + // Run test + cost, err := getCostMatrix(context.Background(), project) + // Check error + assert.ErrorIs(t, err, errNetwork) + assert.Nil(t, cost) + }) +} diff --git a/pkg/config/auth.go b/pkg/config/auth.go index 9dc0d3028..4873fcb8e 100644 --- a/pkg/config/auth.go +++ b/pkg/config/auth.go @@ -1326,14 +1326,31 @@ func (w *web3) fromAuthConfig(remoteConfig v1API.AuthConfigResponse) { } } -func (a *auth) DiffWithRemote(remoteConfig v1API.AuthConfigResponse) ([]byte, error) { +func (a *auth) DiffWithRemote(remoteConfig v1API.AuthConfigResponse, filter ...func(string) bool) ([]byte, error) { copy := a.Clone() + copy.FromRemoteAuthConfig(remoteConfig) + // Confirm cost before enabling addons + for _, keep := range filter { + if a.MFA.Phone.VerifyEnabled && !copy.MFA.Phone.VerifyEnabled { + if !keep(string(v1API.ListProjectAddonsResponseAvailableAddonsTypeAuthMfaPhone)) { + a.MFA.Phone.VerifyEnabled = false + // Enroll cannot be enabled on its own + a.MFA.Phone.EnrollEnabled = false + } + } + if a.MFA.WebAuthn.VerifyEnabled && !copy.MFA.WebAuthn.VerifyEnabled { + if !keep(string(v1API.ListProjectAddonsResponseAvailableAddonsTypeAuthMfaWebAuthn)) { + a.MFA.WebAuthn.VerifyEnabled = false + // Enroll cannot be enabled on its own + a.MFA.WebAuthn.EnrollEnabled = false + } + } + } // Convert the config values into easily comparable remoteConfig values - currentValue, err := ToTomlBytes(copy) + currentValue, err := ToTomlBytes(a) if err != nil { return nil, err } - copy.FromRemoteAuthConfig(remoteConfig) remoteCompare, err := ToTomlBytes(copy) if err != nil { return nil, err diff --git a/pkg/config/updater.go b/pkg/config/updater.go index 4e2159488..d53b43fa8 100644 --- a/pkg/config/updater.go +++ b/pkg/config/updater.go @@ -143,7 +143,7 @@ func (u *ConfigUpdater) UpdateAuthConfig(ctx context.Context, projectRef string, } else if authConfig.JSON200 == nil { return errors.Errorf("unexpected status %d: %s", authConfig.StatusCode(), string(authConfig.Body)) } - authDiff, err := c.DiffWithRemote(*authConfig.JSON200) + authDiff, err := c.DiffWithRemote(*authConfig.JSON200, filter...) if err != nil { return err } else if len(authDiff) == 0 {