64 changes: 37 additions & 27 deletions pkg/limits/frontend/mock_test.go
@@ -2,6 +2,7 @@ package frontend

import (
"context"
"errors"
"fmt"
"testing"

@@ -16,8 +17,8 @@ import (
"github.com/grafana/loki/v3/pkg/limits/proto"
)

// mockExceedsLimitsGatherer mocks an ExeceedsLimitsGatherer. It avoids having to
// set up a mock ring to test the frontend.
// mockExceedsLimitsGatherer mocks an exceedsLimitsGatherer. It avoids having
// to set up a mock ring to test the frontend.
type mockExceedsLimitsGatherer struct {
t *testing.T

@@ -37,50 +38,59 @@ type mockIngestLimitsClient struct {
proto.IngestLimitsClient
t *testing.T

// The requests expected to be received.
expectedAssignedPartitionsRequest *proto.GetAssignedPartitionsRequest
expectedExceedsLimitsRequest *proto.ExceedsLimitsRequest
// The complete set of expected requests over the lifetime of the client.
// We don't check the expected requests for GetAssignedPartitions as it
// has no fields. Instead, tests should check the number of requests
// received with [Finished].
expectedExceedsLimitsRequests []*proto.ExceedsLimitsRequest

// The mocked responses.
getAssignedPartitionsResponse *proto.GetAssignedPartitionsResponse
getAssignedPartitionsResponseErr error
exceedsLimitsResponse *proto.ExceedsLimitsResponse
exceedsLimitsResponseErr error

// The expected request counts.
expectedNumAssignedPartitionsRequests int
expectedNumExceedsLimitsRequests int
// The complete set of mocked responses over the lifetime of the client.
// When a request is received, it consumes the next response (or error)
// until there are no more left. Additional requests fail with an error.
getAssignedPartitionsResponses []*proto.GetAssignedPartitionsResponse
getAssignedPartitionsResponseErrs []error
exceedsLimitsResponses []*proto.ExceedsLimitsResponse
exceedsLimitsResponseErrs []error

// The actual request counts.
numAssignedPartitionsRequests int
numExceedsLimitsRequests int
Comment on lines 56 to 57
Collaborator: Double checking, no need for a mutex for these? No concurrent test cases?
Contributor (author): No mutex required right now. Each test case uses a separate mock, and requests are executed in sequence.

}

func (m *mockIngestLimitsClient) GetAssignedPartitions(_ context.Context, r *proto.GetAssignedPartitionsRequest, _ ...grpc.CallOption) (*proto.GetAssignedPartitionsResponse, error) {
if expected := m.expectedAssignedPartitionsRequest; expected != nil {
require.Equal(m.t, expected, r)
func (m *mockIngestLimitsClient) GetAssignedPartitions(_ context.Context, _ *proto.GetAssignedPartitionsRequest, _ ...grpc.CallOption) (*proto.GetAssignedPartitionsResponse, error) {
idx := m.numAssignedPartitionsRequests
// Check that we haven't received more requests than we have mocked
// responses.
if idx >= len(m.getAssignedPartitionsResponses) {
return nil, errors.New("unexpected GetAssignedPartitionsRequest")
}
m.numAssignedPartitionsRequests++
if err := m.getAssignedPartitionsResponseErr; err != nil {
if err := m.getAssignedPartitionsResponseErrs[idx]; err != nil {
return nil, err
}
return m.getAssignedPartitionsResponse, nil
return m.getAssignedPartitionsResponses[idx], nil
}

func (m *mockIngestLimitsClient) ExceedsLimits(_ context.Context, r *proto.ExceedsLimitsRequest, _ ...grpc.CallOption) (*proto.ExceedsLimitsResponse, error) {
if expected := m.expectedExceedsLimitsRequest; expected != nil {
require.Equal(m.t, expected, r)
func (m *mockIngestLimitsClient) ExceedsLimits(_ context.Context, req *proto.ExceedsLimitsRequest, _ ...grpc.CallOption) (*proto.ExceedsLimitsResponse, error) {
idx := m.numExceedsLimitsRequests
// Check that we haven't received more requests than we have mocked
// responses.
if idx >= len(m.exceedsLimitsResponses) {
return nil, errors.New("unexpected ExceedsLimitsRequest")
}
m.numExceedsLimitsRequests++
if err := m.exceedsLimitsResponseErr; err != nil {
if len(m.expectedExceedsLimitsRequests) > 0 {
require.Equal(m.t, m.expectedExceedsLimitsRequests[idx], req)
}
if err := m.exceedsLimitsResponseErrs[idx]; err != nil {
return nil, err
}
return m.exceedsLimitsResponse, nil
return m.exceedsLimitsResponses[idx], nil
}

func (m *mockIngestLimitsClient) AssertExpectedNumRequests() {
require.Equal(m.t, m.expectedNumAssignedPartitionsRequests, m.numAssignedPartitionsRequests)
require.Equal(m.t, m.expectedNumExceedsLimitsRequests, m.numExceedsLimitsRequests)
func (m *mockIngestLimitsClient) Finished() {
require.Equal(m.t, len(m.getAssignedPartitionsResponses), m.numAssignedPartitionsRequests)
require.Equal(m.t, len(m.exceedsLimitsResponses), m.numExceedsLimitsRequests)
}

func (m *mockIngestLimitsClient) Close() error {
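For orientation, here is a minimal sketch of how a test might drive the reworked mock (the test name and the empty response values are hypothetical, not part of this PR): each expected call gets exactly one queued response, and Finished asserts at the end that every queued response was consumed.

	func TestFrontend_MockUsage(t *testing.T) {
		mockClient := &mockIngestLimitsClient{
			t: t,
			// One queued response per expected ExceedsLimits call; an extra
			// call would fail with "unexpected ExceedsLimitsRequest".
			exceedsLimitsResponses:    []*proto.ExceedsLimitsResponse{{}},
			exceedsLimitsResponseErrs: []error{nil},
			// Likewise, exactly one GetAssignedPartitions call is expected.
			getAssignedPartitionsResponses:    []*proto.GetAssignedPartitionsResponse{{}},
			getAssignedPartitionsResponseErrs: []error{nil},
		}
		// ... exercise the frontend against mockClient here ...
		// Assert that all queued responses were consumed.
		mockClient.Finished()
	}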
111 changes: 91 additions & 20 deletions pkg/limits/frontend/ring.go
@@ -2,6 +2,9 @@ package frontend

import (
"context"
"slices"
"sort"
"strings"

"github.com/go-kit/log"
"github.com/go-kit/log/level"
@@ -19,6 +22,11 @@ const (

var (
LimitsRead = ring.NewOp([]ring.InstanceState{ring.ACTIVE}, nil)

// defaultZoneCmp compares two zones using [strings.Compare].
defaultZoneCmp = func(a, b string) int {
return strings.Compare(a, b)
}
)
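The comparator is an ordinary func(a, b string) int, so tests can substitute their own ordering. A tiny self-contained illustration (the zone names and the reversed comparator are made up for the example, not part of this PR):

	package main

	import (
		"fmt"
		"slices"
		"strings"
	)

	func main() {
		zones := []string{"zone-c", "zone-a", "zone-b"}

		// defaultZoneCmp: lexical order gives a stable, reproducible query order.
		slices.SortFunc(zones, func(a, b string) int { return strings.Compare(a, b) })
		fmt.Println(zones) // [zone-a zone-b zone-c]

		// A test could swap in a different comparator, e.g. a reversed one, to
		// exercise another deterministic order.
		slices.SortFunc(zones, func(a, b string) int { return strings.Compare(b, a) })
		fmt.Println(zones) // [zone-c zone-b zone-a]
	}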

// ringGatherer uses a ring to find limits instances.
@@ -28,6 +36,7 @@ type ringGatherer struct {
pool *ring_client.Pool
numPartitions int
assignedPartitionsCache cache[string, *proto.GetAssignedPartitionsResponse]
zoneCmp func(a, b string) int
}

// newRingGatherer returns a new ringGatherer.
@@ -44,10 +53,11 @@ func newRingGatherer(
pool: pool,
numPartitions: numPartitions,
assignedPartitionsCache: assignedPartitionsCache,
zoneCmp: defaultZoneCmp,
}
}

// ExceedsLimits implements the [exceedsLimitsGatherer] interface.
// exceedsLimits implements the [exceedsLimitsGatherer] interface.
func (g *ringGatherer) exceedsLimits(ctx context.Context, req *proto.ExceedsLimitsRequest) ([]*proto.ExceedsLimitsResponse, error) {
if len(req.Streams) == 0 {
return nil, nil
@@ -56,50 +66,111 @@ func (g *ringGatherer) exceedsLimits(ctx context.Context, req *proto.ExceedsLimi
if err != nil {
return nil, err
}
partitionConsumers, err := g.getPartitionConsumers(ctx, rs.Instances)
// Get the partition consumers for each zone.
zonesPartitions, err := g.getZoneAwarePartitionConsumers(ctx, rs.Instances)
if err != nil {
return nil, err
}
ownedStreams := make(map[string][]*proto.StreamMetadata)
for _, s := range req.Streams {
partition := int32(s.StreamHash % uint64(g.numPartitions))
addr, ok := partitionConsumers[partition]
// In practice we want zones to be queried in random order to spread
// reads. However, in tests we want a deterministic order so test cases
// are stable and reproducible. Having a custom sort func supports both
// use cases as zoneCmp can be switched out in tests.
zonesToQuery := make([]string, 0, len(zonesPartitions))
for zone := range zonesPartitions {
zonesToQuery = append(zonesToQuery, zone)
}
slices.SortFunc(zonesToQuery, g.zoneCmp)
// Make a copy of the streams from the request. We will prune this slice
// each time we receive the responses from a zone.
streams := make([]*proto.StreamMetadata, 0, len(req.Streams))
for _, stream := range req.Streams {
streams = append(streams, stream)
}
// Query each zone as ordered in zonesToQuery. If a zone answers all
// streams, the request is satisfied and there is no need to query
// subsequent zones. If a zone answers just a subset of streams
// (i.e. the instance that is consuming a partition is unavailable or the
// partition that owns one or more streams does not have a consumer)
// then query the next zone for the remaining streams. We repeat this
// process until all streams have been queried or we have exhausted all
// zones.
responses := make([]*proto.ExceedsLimitsResponse, 0)
for _, zone := range zonesToQuery {
resps, answered, err := g.doExceedsLimitsRPCs(ctx, req.Tenant, streams, zonesPartitions[zone])
if err != nil {
continue
}
responses = append(responses, resps...)
Collaborator: nit.
Suggested change:
responses = append(responses, resps...)
if len(streams) == len(answered) {
break
}
Contributor (author): We also need to delete the streams, can't just break, as we will check len(streams) later to see if some streams went unanswered.
Contributor (author): Going to look at this in the next PR 👍

// Remove the answered streams from the slice. The slice of answered
// streams must be sorted so we can use sort.Search to subtract the
// two slices.
slices.Sort(answered)
streams = slices.DeleteFunc(streams, func(stream *proto.StreamMetadata) bool {
// see https://pkg.go.dev/sort#Search
Contributor (author): This should be O(N + logN)

i := sort.Search(len(answered), func(i int) bool {
return answered[i] >= stream.StreamHash
})
return i < len(answered) && answered[i] == stream.StreamHash
})
// All streams have been checked against per-tenant limits.
if len(streams) == 0 {
break
}
}
// TODO(grobinson): In a subsequent change, I will figure out what to do
// about unanswered streams.
Contributor (author): At present, unanswered streams are permitted. First I want to emit a metric to track this, and I will do this in a follow up PR to make the change easier to see.
return responses, nil
}
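The pruning step above can also be read in isolation. Below is a small self-contained sketch of the same slices.Sort + sort.Search + slices.DeleteFunc pattern; plain uint64 hashes stand in for *proto.StreamMetadata, and the helper name is made up.

	package main

	import (
		"fmt"
		"slices"
		"sort"
	)

	// pruneAnswered removes every hash in answered from streams, mirroring the
	// subtraction performed in exceedsLimits: sort the answered hashes once,
	// then binary-search them for each remaining stream.
	func pruneAnswered(streams []uint64, answered []uint64) []uint64 {
		slices.Sort(answered)
		return slices.DeleteFunc(streams, func(h uint64) bool {
			i := sort.Search(len(answered), func(i int) bool { return answered[i] >= h })
			return i < len(answered) && answered[i] == h
		})
	}

	func main() {
		streams := []uint64{10, 20, 30, 40}
		answered := []uint64{30, 10}
		fmt.Println(pruneAnswered(streams, answered)) // [20 40]
	}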

func (g *ringGatherer) doExceedsLimitsRPCs(ctx context.Context, tenant string, streams []*proto.StreamMetadata, partitions map[int32]string) ([]*proto.ExceedsLimitsResponse, []uint64, error) {
// For each stream, figure out which instance consumes its partition.
instancesForStreams := make(map[string][]*proto.StreamMetadata)
for _, stream := range streams {
partition := int32(stream.StreamHash % uint64(g.numPartitions))
addr, ok := partitions[partition]
if !ok {
// TODO(grobinson): Drop streams when ok is false.
level.Warn(g.logger).Log("msg", "no instance found for partition", "partition", partition)
continue
}
ownedStreams[addr] = append(ownedStreams[addr], s)
instancesForStreams[addr] = append(instancesForStreams[addr], stream)
}
errg, ctx := errgroup.WithContext(ctx)
responseCh := make(chan *proto.ExceedsLimitsResponse, len(ownedStreams))
for addr, streams := range ownedStreams {
responseCh := make(chan *proto.ExceedsLimitsResponse, len(instancesForStreams))
answeredCh := make(chan uint64, len(streams))
for addr, streams := range instancesForStreams {
errg.Go(func() error {
client, err := g.pool.GetClientFor(addr)
if err != nil {
level.Error(g.logger).Log("msg", "failed to get client for instance", "instance", addr, "err", err.Error())
return err
return nil
}
resp, err := client.(proto.IngestLimitsClient).ExceedsLimits(ctx, &proto.ExceedsLimitsRequest{
Tenant: req.Tenant,
Tenant: tenant,
Streams: streams,
})
if err != nil {
return err
level.Error(g.logger).Log("failed check execeed limits for instance", "instance", addr, "err", err.Error())
return nil
}
responseCh <- resp
for _, stream := range streams {
answeredCh <- stream.StreamHash
}
return nil
})
}
if err = errg.Wait(); err != nil {
return nil, err
}
_ = errg.Wait()
close(responseCh)
responses := make([]*proto.ExceedsLimitsResponse, 0, len(rs.Instances))
for resp := range responseCh {
responses = append(responses, resp)
close(answeredCh)
responses := make([]*proto.ExceedsLimitsResponse, 0, len(instancesForStreams))
for r := range responseCh {
responses = append(responses, r)
}
return responses, nil
answered := make([]uint64, 0, len(streams))
for streamHash := range answeredCh {
answered = append(answered, streamHash)
}
return responses, answered, nil
}
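And a compact illustration of the stream-to-consumer grouping at the top of doExceedsLimitsRPCs (simplified types, hypothetical names): the partition is the stream hash modulo the partition count, and the zone's partition-to-address map selects the instance; streams whose partition has no consumer are skipped.

	package main

	import "fmt"

	// groupByInstance maps each stream hash to the address consuming its
	// partition; hashes whose partition has no consumer are skipped, mirroring
	// the "no instance found for partition" case above.
	func groupByInstance(hashes []uint64, numPartitions int, partitions map[int32]string) map[string][]uint64 {
		out := make(map[string][]uint64)
		for _, h := range hashes {
			partition := int32(h % uint64(numPartitions))
			addr, ok := partitions[partition]
			if !ok {
				continue
			}
			out[addr] = append(out[addr], h)
		}
		return out
	}

	func main() {
		partitions := map[int32]string{0: "instance-a", 1: "instance-b"} // partition -> consumer address
		fmt.Println(groupByInstance([]uint64{7, 8, 9}, 2, partitions))
		// map[instance-a:[8] instance-b:[7 9]]
	}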

type zonePartitionConsumersResult struct {