From 22179abbffa061e325d6c24264e864e6afd4ad52 Mon Sep 17 00:00:00 2001 From: Periklis Tsirakidis Date: Wed, 28 May 2025 22:46:45 +0200 Subject: [PATCH] fix(ruler): Return StatusBadRequest on multiple org IDs --- pkg/ruler/base/api.go | 21 +++++++- pkg/ruler/base/api_test.go | 107 +++++++++++++++++++++++++++---------- pkg/ruler/base/ruler.go | 2 +- 3 files changed, 101 insertions(+), 29 deletions(-) diff --git a/pkg/ruler/base/api.go b/pkg/ruler/base/api.go index 4e4d71d5691db..58f7c6fa87fa0 100644 --- a/pkg/ruler/base/api.go +++ b/pkg/ruler/base/api.go @@ -151,6 +151,11 @@ func (a *API) PrometheusRules(w http.ResponseWriter, req *http.Request) { logger := util_log.WithContext(req.Context(), a.logger) userID, err := tenant.TenantID(req.Context()) if err != nil || userID == "" { + if errors.Is(err, user.ErrTooManyOrgIDs) { + respondInvalidRequest(logger, w, "too many org ids found") + return + } + level.Error(logger).Log("msg", "error extracting org id from context", "err", err) respondServerError(logger, w, "no valid org id found") return @@ -177,8 +182,12 @@ func (a *API) PrometheusRules(w http.ResponseWriter, req *http.Request) { } rgs, err := a.ruler.GetRules(req.Context(), &rulesReq) - if err != nil { + if errors.Is(err, user.ErrTooManyOrgIDs) { + respondInvalidRequest(logger, w, "too many org ids found") + return + } + respondServerError(logger, w, err.Error()) return } @@ -263,6 +272,11 @@ func (a *API) PrometheusAlerts(w http.ResponseWriter, req *http.Request) { logger := util_log.WithContext(req.Context(), a.logger) userID, err := tenant.TenantID(req.Context()) if err != nil || userID == "" { + if errors.Is(err, user.ErrTooManyOrgIDs) { + respondInvalidRequest(logger, w, "too many org ids found") + return + } + level.Error(logger).Log("msg", "error extracting org id from context", "err", err) respondServerError(logger, w, "no valid org id found") return @@ -272,6 +286,11 @@ func (a *API) PrometheusAlerts(w http.ResponseWriter, req *http.Request) { rgs, err := a.ruler.GetRules(req.Context(), &RulesRequest{Filter: AlertingRule}) if err != nil { + if errors.Is(err, user.ErrTooManyOrgIDs) { + respondInvalidRequest(logger, w, "too many org ids found") + return + } + respondServerError(logger, w, err.Error()) return } diff --git a/pkg/ruler/base/api_test.go b/pkg/ruler/base/api_test.go index fc440c0b5462b..df5a9d74aee9c 100644 --- a/pkg/ruler/base/api_test.go +++ b/pkg/ruler/base/api_test.go @@ -81,6 +81,7 @@ func TestRuler_PrometheusRules(t *testing.T) { } testCases := map[string]struct { + tenantID string configuredRules rulespb.RuleGroupList expectedConfigured int expectedStatusCode int @@ -89,6 +90,7 @@ func TestRuler_PrometheusRules(t *testing.T) { queryParams string }{ "should load and evaluate the configured rules": { + tenantID: userID, configuredRules: rulespb.RuleGroupList{ &rulespb.RuleGroupDesc{ Name: "group1", @@ -133,6 +135,7 @@ func TestRuler_PrometheusRules(t *testing.T) { Interval: interval, }, }, + tenantID: userID, expectedConfigured: 1, expectedRules: []*RuleGroup{ { @@ -168,6 +171,7 @@ func TestRuler_PrometheusRules(t *testing.T) { Interval: interval, }, }, + tenantID: userID, expectedConfigured: 1, queryParams: "?type=alert", expectedRules: []*RuleGroup{ @@ -198,6 +202,7 @@ func TestRuler_PrometheusRules(t *testing.T) { Interval: interval, }, }, + tenantID: userID, expectedConfigured: 1, queryParams: "?type=record", expectedRules: []*RuleGroup{ @@ -217,6 +222,7 @@ func TestRuler_PrometheusRules(t *testing.T) { }, }, "Invalid type param": { + tenantID: userID, configuredRules: rulespb.RuleGroupList{}, expectedConfigured: 0, queryParams: "?type=foo", @@ -224,13 +230,24 @@ func TestRuler_PrometheusRules(t *testing.T) { expectedErrorType: v1.ErrBadData, expectedRules: []*RuleGroup{}, }, + "Too many org ids": { + tenantID: "user1|user2|user3", + configuredRules: rulespb.RuleGroupList{}, + expectedConfigured: 0, + queryParams: "?type=record", + expectedStatusCode: http.StatusBadRequest, + expectedErrorType: v1.ErrBadData, + expectedRules: []*RuleGroup{}, + }, "when filtering by an unknown namespace then the API returns nothing": { + tenantID: userID, configuredRules: makeFilterTestRules(), expectedConfigured: len(makeFilterTestRules()), queryParams: "?file=unknown", expectedRules: []*RuleGroup{}, }, "when filtering by a single known namespace then the API returns only rules from that namespace": { + tenantID: userID, configuredRules: makeFilterTestRules(), expectedConfigured: len(makeFilterTestRules()), queryParams: "?" + url.Values{"file": []string{namespaceName(1)}}.Encode(), @@ -265,6 +282,7 @@ func TestRuler_PrometheusRules(t *testing.T) { }, }, "when filtering by a multiple known namespaces then the API returns rules from both namespaces": { + tenantID: userID, configuredRules: makeFilterTestRules(), expectedConfigured: len(makeFilterTestRules()), queryParams: "?" + url.Values{"file": []string{namespaceName(1), namespaceName(2)}}.Encode(), @@ -326,12 +344,14 @@ func TestRuler_PrometheusRules(t *testing.T) { }, }, "when filtering by an unknown group then the API returns nothing": { + tenantID: userID, configuredRules: makeFilterTestRules(), expectedConfigured: len(makeFilterTestRules()), queryParams: "?rule_group=unknown", expectedRules: []*RuleGroup{}, }, "when filtering by a known group then the API returns only rules from that group": { + tenantID: userID, configuredRules: makeFilterTestRules(), expectedConfigured: len(makeFilterTestRules()), queryParams: "?" + url.Values{"rule_group": []string{groupName(2)}}.Encode(), @@ -366,6 +386,7 @@ func TestRuler_PrometheusRules(t *testing.T) { }, }, "when filtering by multiple known groups then the API returns rules from both groups": { + tenantID: userID, configuredRules: makeFilterTestRules(), expectedConfigured: len(makeFilterTestRules()), queryParams: "?" + url.Values{"rule_group": []string{groupName(2), groupName(3)}}.Encode(), @@ -428,12 +449,14 @@ func TestRuler_PrometheusRules(t *testing.T) { }, "when filtering by an unknown rule name then the API returns all empty groups": { + tenantID: userID, configuredRules: makeFilterTestRules(), expectedConfigured: len(makeFilterTestRules()), queryParams: "?rule_name=unknown", expectedRules: []*RuleGroup{}, }, "when filtering by a known rule name then the API returns only rules with that name": { + tenantID: userID, configuredRules: makeFilterTestRules(), expectedConfigured: len(makeFilterTestRules()), queryParams: "?" + url.Values{"rule_name": []string{"UniqueNamedRuleN1G2"}}.Encode(), @@ -449,6 +472,7 @@ func TestRuler_PrometheusRules(t *testing.T) { }, }, "when filtering by multiple known rule names then the API returns both rules": { + tenantID: userID, configuredRules: makeFilterTestRules(), expectedConfigured: len(makeFilterTestRules()), queryParams: "?" + url.Values{"rule_name": []string{"UniqueNamedRuleN1G2", "UniqueNamedRuleN2G3"}}.Encode(), @@ -472,6 +496,7 @@ func TestRuler_PrometheusRules(t *testing.T) { }, }, "when filtering by a known namespace and group then the API returns only rules from that namespace and group": { + tenantID: userID, configuredRules: makeFilterTestRules(), expectedConfigured: len(makeFilterTestRules()), queryParams: "?" + url.Values{ @@ -516,7 +541,7 @@ func TestRuler_PrometheusRules(t *testing.T) { a := NewAPI(r, r.store, log.NewNopLogger()) - req := requestFor(t, "GET", "https://localhost:8080/api/prom/api/v1/rules"+tc.queryParams, nil, "user1") + req := requestFor(t, "GET", "https://localhost:8080/api/prom/api/v1/rules"+tc.queryParams, nil, tc.tenantID) w := httptest.NewRecorder() a.PrometheusRules(w, req) @@ -558,35 +583,63 @@ func TestRuler_PrometheusRules(t *testing.T) { func TestRuler_PrometheusAlerts(t *testing.T) { cfg := defaultRulerConfig(t, newMockRuleStore(mockRules)) - r := newTestRuler(t, cfg) - defer services.StopAndAwaitTerminated(context.Background(), r) //nolint:errcheck + tests := []struct { + name string + tenantID string + expectedStatusCode int + expectedResponse response + }{ + { + name: "single org id", + tenantID: "user1", + expectedStatusCode: http.StatusOK, + expectedResponse: response{ + Status: "success", + Data: &AlertDiscovery{ + Alerts: []*Alert{}, + }, + }, + }, + { + name: "multiple org ids", + tenantID: "user1|user2|user3", + expectedStatusCode: http.StatusBadRequest, + expectedResponse: response{ + Status: "error", + ErrorType: v1.ErrBadData, + Error: "too many org ids found", + }, + }, + } - a := NewAPI(r, r.store, log.NewNopLogger()) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + r := newTestRuler(t, cfg) + defer services.StopAndAwaitTerminated(context.Background(), r) //nolint:errcheck - req := requestFor(t, http.MethodGet, "https://localhost:8080/api/prom/api/v1/alerts", nil, "user1") - w := httptest.NewRecorder() - a.PrometheusAlerts(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - // Check status code and status response - responseJSON := response{} - err := json.Unmarshal(body, &responseJSON) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - require.Equal(t, responseJSON.Status, "success") - - // Currently there is not an easy way to mock firing alerts. The empty - // response case is tested instead. - expectedResponse, _ := json.Marshal(response{ - Status: "success", - Data: &AlertDiscovery{ - Alerts: []*Alert{}, - }, - }) + a := NewAPI(r, r.store, log.NewNopLogger()) + + req := requestFor(t, http.MethodGet, "https://localhost:8080/api/prom/api/v1/alerts", nil, test.tenantID) + w := httptest.NewRecorder() + a.PrometheusAlerts(w, req) - require.Equal(t, string(expectedResponse), string(body)) + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + + // Check status code and status response + responseJSON := response{} + err := json.Unmarshal(body, &responseJSON) + require.NoError(t, err) + require.Equal(t, test.expectedStatusCode, resp.StatusCode) + require.Equal(t, test.expectedResponse.Status, responseJSON.Status) + + // Currently there is not an easy way to mock firing alerts. The empty + // response case is tested instead. + expectedResponse, err := json.Marshal(test.expectedResponse) + require.NoError(t, err) + require.Equal(t, string(expectedResponse), string(body)) + }) + } } func TestRuler_GetRulesLabelFilter(t *testing.T) { diff --git a/pkg/ruler/base/ruler.go b/pkg/ruler/base/ruler.go index 4e466810570b1..063538840d43f 100644 --- a/pkg/ruler/base/ruler.go +++ b/pkg/ruler/base/ruler.go @@ -802,7 +802,7 @@ func RemoveRuleTokenFromGroupName(name string) string { func (r *Ruler) GetRules(ctx context.Context, req *RulesRequest) ([]*GroupStateDesc, error) { userID, err := tenant.TenantID(ctx) if err != nil { - return nil, fmt.Errorf("no user id found in context") + return nil, fmt.Errorf("no user id found in context: %w", err) } if r.cfg.EnableSharding {