Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions internal/ghmcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,11 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv
}

// Construct REST client
rateLimitState := transport.NewRateLimitState()
rateLimitLogger := cfg.Logger.With("component", "rate_limit")

restUATransport := &transport.UserAgentTransport{
Transport: http.DefaultTransport,
Transport: transport.WrapWithRateLimit(http.DefaultTransport, rateLimitState, rateLimitLogger),
Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version),
}
restClient, err := gogithub.NewClient(
Expand All @@ -80,7 +83,7 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv
gqlHTTPClient := &http.Client{
Transport: &transport.BearerAuthTransport{
Transport: &transport.GraphQLFeaturesTransport{
Transport: http.DefaultTransport,
Transport: transport.WrapWithRateLimit(http.DefaultTransport, rateLimitState, rateLimitLogger),
},
Token: cfg.Token,
},
Expand Down
14 changes: 12 additions & 2 deletions pkg/github/dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ type RequestDeps struct {

// Observability exporters (includes logger)
obsv observability.Exporters

rateLimits *transport.RateLimitRegistry
}

// NewRequestDeps creates a RequestDeps with the provided clients and configuration.
Expand All @@ -298,6 +300,7 @@ func NewRequestDeps(
ContentWindowSize: contentWindowSize,
featureChecker: featureChecker,
obsv: obsv,
rateLimits: transport.NewRateLimitRegistry(),
}
}

Expand All @@ -320,9 +323,15 @@ func (d *RequestDeps) GetClient(ctx context.Context) (*gogithub.Client, error) {
}

// Construct REST client
rateLimitLogger := d.obsv.Logger().With("component", "rate_limit")
restClient, err := gogithub.NewClient(
gogithub.WithHTTPClient(&http.Client{
Transport: &transport.UserAgentTransport{
Transport: transport.WrapWithRateLimit(http.DefaultTransport, d.rateLimits.Get(token), rateLimitLogger),
Agent: fmt.Sprintf("github-mcp-server/%s", d.version),
},
}),
gogithub.WithAuthToken(token),
gogithub.WithUserAgent(fmt.Sprintf("github-mcp-server/%s", d.version)),
gogithub.WithEnterpriseURLs(baseRestURL.String(), uploadURL.String()),
)
if err != nil {
Expand All @@ -344,10 +353,11 @@ func (d *RequestDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error
// We use NewEnterpriseClient unconditionally since we already parsed the API host
// Wrap transport with GraphQLFeaturesTransport to inject feature flags from context,
// matching the transport chain used by the remote server.
rateLimitLogger := d.obsv.Logger().With("component", "rate_limit")
gqlHTTPClient := &http.Client{
Transport: &transport.BearerAuthTransport{
Transport: &transport.GraphQLFeaturesTransport{
Transport: http.DefaultTransport,
Transport: transport.WrapWithRateLimit(http.DefaultTransport, d.rateLimits.Get(token), rateLimitLogger),
},
Token: token,
},
Expand Down
39 changes: 39 additions & 0 deletions pkg/github/dependencies_ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package github

import (
"log/slog"
"testing"

"github.com/github/github-mcp-server/pkg/observability"
"github.com/github/github-mcp-server/pkg/observability/metrics"
"github.com/github/github-mcp-server/pkg/translations"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewRequestDeps_InitializesRateLimitRegistry(t *testing.T) {
t.Parallel()

obs, err := observability.NewExporters(slog.New(slog.DiscardHandler), metrics.NewNoopMetrics())
require.NoError(t, err)

deps := NewRequestDeps(
nil,
"test",
false,
nil,
translations.NullTranslationHelper,
0,
nil,
obs,
)

require.NotNil(t, deps.rateLimits)

stateA1 := deps.rateLimits.Get("token-a")
stateA2 := deps.rateLimits.Get("token-a")
stateB := deps.rateLimits.Get("token-b")

assert.Same(t, stateA1, stateA2)
assert.NotSame(t, stateA1, stateB)
}
230 changes: 230 additions & 0 deletions pkg/http/transport/rate_limit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
package transport

import (
"context"
"log/slog"
"net/http"
"strconv"
"sync"
"time"
)

const (
DefaultMinRateLimitRemaining = 50
DefaultMinRequestInterval = 50 * time.Millisecond
DefaultMaxRateLimitRetries = 3
)

type RateLimitState struct {
mu sync.Mutex

remaining int // -1 means unknown
reset time.Time
lastReq time.Time
}

func NewRateLimitState() *RateLimitState {
return &RateLimitState{remaining: -1}
}

type RateLimitRegistry struct {
states sync.Map
}

func NewRateLimitRegistry() *RateLimitRegistry {
return &RateLimitRegistry{}
}

func (r *RateLimitRegistry) Get(token string) *RateLimitState {
if state, ok := r.states.Load(token); ok {
return state.(*RateLimitState)
}

state := NewRateLimitState()
actual, _ := r.states.LoadOrStore(token, state)
return actual.(*RateLimitState)
}

type RateLimitTransport struct {
Transport http.RoundTripper
State *RateLimitState

MinInterval time.Duration
MinRemaining int
MaxRetries int
Logger *slog.Logger
}

func WrapWithRateLimit(base http.RoundTripper, state *RateLimitState, logger *slog.Logger) http.RoundTripper {
if state == nil {
state = NewRateLimitState()
}

return &RateLimitTransport{
Transport: base,
State: state,
MinInterval: DefaultMinRequestInterval,
MinRemaining: DefaultMinRateLimitRemaining,
MaxRetries: DefaultMaxRateLimitRetries,
Logger: logger,
}
}

func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) {
transport := t.Transport
if transport == nil {
transport = http.DefaultTransport
}

maxRetries := t.MaxRetries
if maxRetries < 0 {
maxRetries = DefaultMaxRateLimitRetries
}

for attempt := 0; attempt <= maxRetries; attempt++ {
t.waitBeforeRequest(req.Context())

resp, err := transport.RoundTrip(req)
if err != nil {
return resp, err
}

t.updateFromResponse(resp)

if !isRateLimitedResponse(resp) || attempt == maxRetries {
return resp, nil
}

wait := retryAfterDuration(resp)
if t.Logger != nil {
t.Logger.Warn(
"GitHub API rate limit hit, waiting before retry",
"attempt", attempt+1,
"max_retries", maxRetries,
"wait", wait.Round(time.Second),
"status", resp.StatusCode,
)
}

resp.Body.Close()
waitForContext(req.Context(), wait)
}

return nil, nil
}

func (t *RateLimitTransport) waitBeforeRequest(ctx context.Context) {
minInterval := t.MinInterval
if minInterval <= 0 {
minInterval = DefaultMinRequestInterval
}

minRemaining := t.MinRemaining
if minRemaining <= 0 {
minRemaining = DefaultMinRateLimitRemaining
}

t.State.mu.Lock()
defer t.State.mu.Unlock()

if wait := time.Until(t.State.lastReq.Add(minInterval)); wait > 0 {
waitForContext(ctx, wait)
}

if t.State.remaining >= 0 && t.State.remaining < minRemaining && !t.State.reset.IsZero() {
if wait := time.Until(t.State.reset) + time.Second; wait > 0 {
if t.Logger != nil {
t.Logger.Warn(
"GitHub API rate limit nearly exhausted, waiting for reset",
"remaining", t.State.remaining,
"wait", wait.Round(time.Second),
)
}
waitForContext(ctx, wait)
t.State.remaining = -1
}
}

t.State.lastReq = time.Now()
}

func (t *RateLimitTransport) updateFromResponse(resp *http.Response) {
remaining, reset, ok := parseRateLimitHeaders(resp)
if !ok {
return
}

t.State.mu.Lock()
defer t.State.mu.Unlock()
t.State.remaining = remaining
t.State.reset = reset
}

func parseRateLimitHeaders(resp *http.Response) (remaining int, reset time.Time, ok bool) {
remainingStr := resp.Header.Get("X-RateLimit-Remaining")
resetStr := resp.Header.Get("X-RateLimit-Reset")
if remainingStr == "" || resetStr == "" {
return 0, time.Time{}, false
}

remainingVal, err := strconv.Atoi(remainingStr)
if err != nil {
return 0, time.Time{}, false
}

resetUnix, err := strconv.ParseInt(resetStr, 10, 64)
if err != nil {
return 0, time.Time{}, false
}

return remainingVal, time.Unix(resetUnix, 0), true
}

func isRateLimitedResponse(resp *http.Response) bool {
if resp == nil {
return false
}

switch resp.StatusCode {
case http.StatusTooManyRequests:
return true
case http.StatusForbidden:
return resp.Header.Get("Retry-After") != ""
default:
return false
}
}

func retryAfterDuration(resp *http.Response) time.Duration {
if resp == nil {
return time.Second
}

if retryAfter := resp.Header.Get("Retry-After"); retryAfter != "" {
if seconds, err := strconv.Atoi(retryAfter); err == nil && seconds > 0 {
return time.Duration(seconds) * time.Second
}
}

if _, reset, ok := parseRateLimitHeaders(resp); ok && !reset.IsZero() {
if wait := time.Until(reset) + time.Second; wait > 0 {
return wait
}
}

return time.Second
}

func waitForContext(ctx context.Context, d time.Duration) {
if d <= 0 {
return
}

timer := time.NewTimer(d)
defer timer.Stop()

select {
case <-ctx.Done():
case <-timer.C:
}
}
Loading