I am working with some third-party APIs, each of which has its own rate limit. Endpoint 1 has a rate limit of 10/s, and endpoint 2 has a rate limit of 20/s.
I need to process my data through endpoint 1, which returns an array of objects (between 2 and 3000 of them). I then need to take each object and send some of its data to endpoint 2, while respecting endpoint 2's rate limit.
I plan to send the requests to the first endpoint in goroutines, in batches of 10 at a time, making sure that if all 10 requests complete in under 1 second, I do not exceed the limit by sending more within that same 1-second window.
Ultimately, I want to be able to cap the number of concurrent requests in flight per endpoint, especially if I have to retry requests that failed due to server 500+ responses and the like.
For the purposes of this question, I am using httpbin requests to simulate the scenario below:
package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "sync"
    "time"
)

type HttpBinGetRequest struct {
    url string
}

type HttpBinGetResponse struct {
    Uuid       string `json:"uuid"`
    StatusCode int
}

type HttpBinPostRequest struct {
    url  string
    uuid string // Item to post to API
}

type HttpBinPostResponse struct {
    Data       string `json:"data"`
    StatusCode int
}
func main() {
    // Prepare 500 GET requests
    var requests []*HttpBinGetRequest
    for i := 0; i < 500; i++ {
        uri := "https://httpbin.org/uuid"
        request := &HttpBinGetRequest{
            url: uri,
        }
        requests = append(requests, request)
    }

    // Create semaphore and rate limit for the GET endpoint
    getSemaphore := make(chan struct{}, 10)
    getRate := make(chan struct{}, 10)
    for i := 0; i < cap(getRate); i++ {
        getRate <- struct{}{}
    }

    go func() {
        // ticker corresponding to 1/10th of a second
        ticker := time.NewTicker(100 * time.Millisecond)
        defer ticker.Stop()
        for range ticker.C {
            _, ok := <-getRate
            if !ok {
                return
            }
        }
    }()

    // Send our GET requests to obtain a random UUID
    var wg sync.WaitGroup
    for _, request := range requests {
        wg.Add(1)
        // Go func to make request and receive the response
        go func(r *HttpBinGetRequest) {
            defer wg.Done()
            // Acquire a rate token; this send blocks while the channel is
            // full (the ticker goroutine above drains one token every 100ms)
            getRate <- struct{}{}
            // Add a token to the semaphore
            getSemaphore <- struct{}{}
            // Remove token when function is complete
            defer func() {
                <-getSemaphore
            }()
            resp, _ := get(r)
            fmt.Printf("%+v\n", resp)
        }(request)
    }
    wg.Wait()

    // I need to add code that obtains the response data from the above loop,
    // then sends each UUID to its own goroutine for a POST request, following
    // a pattern similar to the above, so that the rate limit of the second
    // endpoint (20 calls per second) is not violated.
    // postSemaphore := make(chan struct{}, 20)
    // postRate := make(chan struct{}, 20)
    // for i := 0; i < cap(postRate); i++ {
    //     postRate <- struct{}{}
    // }
}
func get(hbgr *HttpBinGetRequest) (*HttpBinGetResponse, error) {
    httpResp := &HttpBinGetResponse{}
    client := &http.Client{}
    req, err := http.NewRequest("GET", hbgr.url, nil)
    if err != nil {
        fmt.Println("error making request")
        return httpResp, err
    }

    req.Header = http.Header{
        "accept": {"application/json"},
    }

    resp, err := client.Do(req)
    if err != nil {
        fmt.Println(err)
        fmt.Println("error getting response")
        return httpResp, err
    }
    defer resp.Body.Close()

    // Read response
    body, err := io.ReadAll(resp.Body)
    if err != nil {
        fmt.Println("error reading response body")
        return httpResp, err
    }
    if err := json.Unmarshal(body, httpResp); err != nil {
        fmt.Println("error unmarshalling response body")
        return httpResp, err
    }
    httpResp.StatusCode = resp.StatusCode
    return httpResp, nil
}
// Method to post data to httpbin
func post(hbr *HttpBinPostRequest) (*HttpBinPostResponse, error) {
    httpResp := &HttpBinPostResponse{}
    client := &http.Client{}
    req, err := http.NewRequest("POST", hbr.url, bytes.NewBuffer([]byte(hbr.uuid)))
    if err != nil {
        fmt.Println("error making request")
        return httpResp, err
    }

    req.Header = http.Header{
        "accept": {"application/json"},
    }

    resp, err := client.Do(req)
    if err != nil {
        fmt.Println("error getting response")
        return httpResp, err
    }
    defer resp.Body.Close()

    if resp.StatusCode == http.StatusTooManyRequests {
        fmt.Println(resp.Header.Get("Retry-After"))
    }

    // Read response
    body, err := io.ReadAll(resp.Body)
    if err != nil {
        fmt.Println("error reading response body")
        return httpResp, err
    }
    if err := json.Unmarshal(body, httpResp); err != nil {
        fmt.Println("error unmarshalling response body")
        return httpResp, err
    }
    httpResp.StatusCode = resp.StatusCode
    fmt.Printf("%+v", httpResp)
    return httpResp, nil
}
This is a producer/consumer pattern. You can use a chan to connect the two stages.
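To make the shape of that pattern concrete before the full demo below, here is a minimal, self-contained sketch (the item names are made up); the demo fills in the HTTP calls, rate limits, and retries:

package main

import "fmt"

func main() {
    ch := make(chan string, 10)

    // Producer: pushes items onto the channel, then closes it to signal
    // the consumer that no more items are coming.
    go func() {
        defer close(ch)
        for i := 0; i < 3; i++ {
            ch <- fmt.Sprintf("id-%d", i)
        }
    }()

    // Consumer: the range loop exits automatically once the channel is
    // closed and drained.
    for id := range ch {
        fmt.Println("post", id)
    }
}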
For the rate limiter, I would use the package golang.org/x/time/rate.
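As a quick illustration of how the limiter paces calls (a minimal sketch; the numbers are just examples): rate.NewLimiter(rate.Limit(10), 1) allows 10 events per second with a burst of 1, and Wait blocks until the next token is available or the context is cancelled.

package main

import (
    "context"
    "fmt"
    "time"

    "golang.org/x/time/rate"
)

func main() {
    // 10 tokens per second, burst of 1: calls end up spaced ~100ms apart.
    lim := rate.NewLimiter(rate.Limit(10), 1)
    ctx := context.Background()
    start := time.Now()
    for i := 0; i < 5; i++ {
        // Wait blocks until a token is available (or ctx is done).
        if err := lim.Wait(ctx); err != nil {
            fmt.Println("wait:", err)
            return
        }
        fmt.Printf("call %d at %v\n", i, time.Since(start).Round(time.Millisecond))
    }
}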
To limit the number of goroutines, I would use the package golang.org/x/sync/errgroup.
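SetLimit caps how many goroutines run at once, and Go blocks until a slot frees up. A minimal sketch (the sleep just stands in for a request):

package main

import (
    "fmt"
    "time"

    "golang.org/x/sync/errgroup"
)

func main() {
    g := new(errgroup.Group)
    g.SetLimit(3) // at most 3 goroutines at a time; g.Go blocks when the limit is hit
    for i := 0; i < 9; i++ {
        i := i
        g.Go(func() error {
            fmt.Println("working on", i)
            time.Sleep(100 * time.Millisecond) // stand-in for an HTTP request
            return nil
        })
    }
    g.Wait() // waits until all launched goroutines have returned
}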
I am not sure whether this is best practice; comments are welcome!
Here is the demo:
package main

import (
    "context"
    "errors"
    "fmt"
    "io"
    "log"
    "math/rand"
    "net/http"
    "net/http/httptest"
    "sort"
    "sync"
    "time"

    "golang.org/x/sync/errgroup"
    "golang.org/x/time/rate"
)
func main() {
    s := &server{}
    ts := httptest.NewServer(s)
    defer ts.Close()

    // The get action is a producer and the post action is a consumer.
    // Choose a chan size according to the producing/consuming speed.
    ch := make(chan string, 10)

    ctx := context.Background()
    maxGet, maxPost := 10, 20
    retryCount := 3

    g1 := new(errgroup.Group)
    // g1 is used to limit the number of goroutines for the get actions.
    // Choose the number according to the latency of the get request.
    g1.SetLimit(maxGet * 2)
    lim1 := rate.NewLimiter(rate.Limit(maxGet), 1)
    go func() {
        for i := 0; i < 50; i++ {
            i := i
            g1.Go(func() error {
                for j := 0; j < retryCount; j++ {
                    id, err := get(ctx, lim1, fmt.Sprintf("%s/%d", ts.URL, i))
                    if err != nil {
                        if errors.Is(err, context.Canceled) {
                            return err
                        }
                        log.Printf("get %d: %v", i, err)
                        continue
                    }
                    ch <- id
                    return nil
                }
                log.Printf("get request %d failed after %d retries", i, retryCount)
                return nil
            })
        }
        g1.Wait()
        close(ch)
    }()

    g2 := new(errgroup.Group)
    // g2 is used to limit the number of goroutines for the post actions.
    // Choose the number according to the latency of the post request.
    g2.SetLimit(maxPost * 3)
    lim2 := rate.NewLimiter(rate.Limit(maxPost), 1)
    for id := range ch {
        id := id
        g2.Go(func() error {
            for j := 0; j < retryCount; j++ {
                err := post(ctx, lim2, fmt.Sprintf("%s/%s", ts.URL, id))
                if err != nil {
                    if errors.Is(err, context.Canceled) {
                        return err
                    }
                    log.Printf("post: %v", err)
                    continue
                }
                return nil
            }
            log.Printf("post request %s failed after %d retries", id, retryCount)
            return nil
        })
    }
    g2.Wait()

    s.printStats()
}
func get(ctx context.Context, lim *rate.Limiter, url string) (string, error) {
    if err := lim.Wait(ctx); err != nil {
        return "", err
    }
    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
    if err != nil {
        return "", err
    }
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return "", err
    }
    defer resp.Body.Close()
    if resp.StatusCode != 200 {
        return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
    }
    body, err := io.ReadAll(resp.Body)
    if err != nil {
        return "", err
    }
    return string(body), nil
}

func post(ctx context.Context, lim *rate.Limiter, url string) error {
    if err := lim.Wait(ctx); err != nil {
        return err
    }
    req, err := http.NewRequestWithContext(ctx, "POST", url, nil)
    if err != nil {
        return err
    }
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return err
    }
    defer resp.Body.Close()
    if resp.StatusCode != 200 {
        return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
    }
    return nil
}
type server struct {
    gMu  sync.Mutex
    gets []int64

    pMu   sync.Mutex
    posts []int64
}

func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    log.Printf("%s: %s", r.Method, r.URL.Path)

    // Collect request stats.
    if r.Method == http.MethodGet {
        s.gMu.Lock()
        s.gets = append(s.gets, time.Now().UnixMilli())
        s.gMu.Unlock()
    } else {
        s.pMu.Lock()
        s.posts = append(s.posts, time.Now().UnixMilli())
        s.pMu.Unlock()
    }

    n := rand.Intn(1000)

    // Simulate latency.
    time.Sleep(time.Duration(n) * time.Millisecond)

    // Simulate errors.
    if n%10 == 0 {
        w.WriteHeader(http.StatusInternalServerError)
        return
    }

    if r.Method == http.MethodGet {
        fmt.Fprintf(w, "%s", r.URL.Path[1:])
        return
    }
}

func (s *server) printStats() {
    log.Printf("GETS (total: %d):\n", len(s.gets))
    printStats(s.gets)
    log.Printf("POSTS (total: %d):\n", len(s.posts))
    printStats(s.posts)
}

// printStats buckets the sorted timestamps into 1-second windows and prints
// the request count per window.
func printStats(ts []int64) {
    sort.Slice(ts, func(i, j int) bool {
        return ts[i] < ts[j]
    })

    count := 0
    to := ts[0] + 1000
    for i := 0; i < len(ts); i++ {
        if ts[i] < to {
            count++
        } else {
            fmt.Printf(" %d: %d\n", to, count)
            i-- // push back the current item
            count = 0
            to += 1000
        }
    }
    if count > 0 {
        fmt.Printf(" %d: %d\n", to, count)
    }
}
The output looks like this:
...
2023/03/25 14:59:12 GETS (total: 56):
1679727546667: 10
1679727547667: 10
1679727548667: 10
1679727549667: 10
1679727550667: 10
1679727551667: 5
1679727552667: 1
2023/03/25 14:59:12 POSTS (total: 55):
1679727546749: 8
1679727547749: 8
1679727548749: 12
1679727549749: 8
1679727550749: 10
1679727551749: 6
1679727552749: 3