在两个单独的限速端点之间同步请求

问题描述 投票:0回答:1

我正在使用一些第三方 API,每个 API 都有自己的速率限制。端点 1 的速率限制为 10/s,端点 2 的速率限制为 20/s。

我需要通过端点 1 处理我的数据,它将返回一个对象数组(2-3000 个对象之间)。然后我需要获取每个对象并将一些数据发送到第二个端点,同时遵守第二个端点的速率限制。

我计划在 goroutine 中以每批 10 个的方式发送对第一个端点的请求，确保即使所有 10 个请求都在 1 秒内完成，我也不会在该 1 秒窗口内继续发送更多请求而超出限制。

最终，我希望能够限制每个端点同时发出的并发请求数——特别是当我必须因服务器返回 500 等错误而重试失败的请求时。

出于问题的目的,我正在使用 httpbin 请求来模拟以下场景:

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "sync"
    "time"
)

// HttpBinGetRequest describes one GET call to the httpbin UUID endpoint.
type HttpBinGetRequest struct {
	url string // target URL, set when the request is built in main
}

// HttpBinGetResponse holds the decoded body of a httpbin /uuid response
// plus the HTTP status code observed on the wire.
type HttpBinGetResponse struct {
	Uuid       string `json:"uuid"`
	StatusCode int // set manually from the *http.Response, not from JSON
}

// HttpBinPostRequest describes one POST call carrying a UUID payload.
type HttpBinPostRequest struct {
	url  string
	uuid string // Item to post to API
}

// HttpBinPostResponse holds the decoded echo body of a httpbin POST
// response plus the HTTP status code observed on the wire.
type HttpBinPostResponse struct {
	Data       string `json:"data"`
	StatusCode int // set manually from the *http.Response, not from JSON
}

// main fires 500 GET requests at httpbin, limited to 10 concurrent
// requests and 10 requests/second via a semaphore channel plus a
// token-bucket channel drained by a ticker goroutine.
func main() {

	// Prepare GET requests for 500 requests
	var requests []*HttpBinGetRequest
	for i := 0; i < 500; i++ {
		uri := "https://httpbin.org/uuid"
		request := &HttpBinGetRequest{
			url: uri,
		}
		requests = append(requests, request)
	}

	// getSemaphore caps in-flight requests at 10; getRate is a token
	// bucket enforcing 10 requests/second. It starts full so workers
	// block immediately once 10 tokens are outstanding.
	getSemaphore := make(chan struct{}, 10)
	getRate := make(chan struct{}, 10)
	for i := 0; i < cap(getRate); i++ {
		getRate <- struct{}{}
	}

	go func() {
		// Drain one token every 100ms -> 10 tokens/second.
		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()
		for range ticker.C {
			_, ok := <-getRate
			if !ok {
				// getRate was closed after all work finished; exit so
				// this goroutine does not leak.
				return
			}
		}
	}()

	// Send our GET requests to obtain a random UUID
	var wg sync.WaitGroup
	for _, request := range requests {
		wg.Add(1)
		go func(r *HttpBinGetRequest) {
			defer wg.Done()

			// Block until the rate limiter has room in this 1s window.
			getRate <- struct{}{}

			// Acquire a concurrency slot; release it when done.
			getSemaphore <- struct{}{}
			defer func() {
				<-getSemaphore
			}()

			resp, err := get(r)
			if err != nil {
				// Surface failures instead of silently dropping them;
				// a real implementation would retry here.
				fmt.Printf("get %s: %v\n", r.url, err)
				return
			}
			fmt.Printf("%+v\n", resp)
		}(request)
	}
	wg.Wait()

	// All senders are finished: close getRate so the ticker goroutine
	// sees !ok and terminates instead of leaking.
	close(getRate)

	// I need to add code that obtains the response data from the above for loop
	// then sends the UUID it to its own go routines for a POST request, following a similar pattern above
	// To not violate the rate limit of the second endpoint which is 20 calls per second
	// postSemaphore := make(chan struct{}, 20)
	// postRate := make(chan struct{}, 20)
	// for i := 0; i < cap(postRate); i++ {
	//  postRate <- struct{}{}
	// }
}

// get performs the GET request described by hbgr and decodes the httpbin
// JSON body into an HttpBinGetResponse. It always returns a non-nil
// response; StatusCode is only populated when the body was read and
// decoded successfully.
func get(hbgr *HttpBinGetRequest) (*HttpBinGetResponse, error) {

	httpResp := &HttpBinGetResponse{}
	// A timeout prevents a stalled server from blocking a worker forever.
	client := &http.Client{Timeout: 30 * time.Second}

	req, err := http.NewRequest("GET", hbgr.url, nil)
	if err != nil {
		return httpResp, fmt.Errorf("building request: %w", err)
	}

	req.Header = http.Header{
		"accept": {"application/json"},
	}

	resp, err := client.Do(req)
	if err != nil {
		return httpResp, fmt.Errorf("sending request: %w", err)
	}
	// Always close the body so the underlying connection can be reused.
	defer resp.Body.Close()

	// Read Response
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return httpResp, fmt.Errorf("reading response body: %w", err)
	}
	// httpResp is already a pointer; passing it directly avoids the
	// original's accidental **HttpBinGetResponse.
	if err := json.Unmarshal(body, httpResp); err != nil {
		return httpResp, fmt.Errorf("decoding response body: %w", err)
	}
	httpResp.StatusCode = resp.StatusCode
	return httpResp, nil
}

// Method to post data to httpbin
func post(hbr *HttpBinPostRequest) (*HttpBinPostResponse, error) {

    httpResp := &HttpBinPostResponse{}
    client := &http.Client{}
    req, err := http.NewRequest("POST", hbr.url, bytes.NewBuffer([]byte(hbr.uuid)))
    if err != nil {
        fmt.Println("error making request")
        return httpResp, err
    }

    req.Header = http.Header{
        "accept": {"application/json"},
    }

    resp, err := client.Do(req)
    if err != nil {
        fmt.Println("error getting response")
        return httpResp, err
    }

    if resp.StatusCode == 429 {
        fmt.Println(resp.Header.Get("Retry-After"))
    }

    // Read Response
    body, err := io.ReadAll(resp.Body)
    if err != nil {
        fmt.Println("error reading response body")
        return httpResp, err
    }
    json.Unmarshal(body, &httpResp)
    httpResp.StatusCode = resp.StatusCode
    fmt.Printf("%+v", httpResp)
    return httpResp, nil
}
go semaphore rate-limiting
1个回答
0
投票

这是生产者/消费者模式。您可以使用 chan 来连接它们。

关于速率限制器,我会使用包

golang.org/x/time/rate
.

我会使用包

golang.org/x/sync/errgroup
来限制goroutines的数量。

我不确定这是否是最佳实践。欢迎评论!

演示来了:

package main

import (
    "context"
    "errors"
    "fmt"
    "io"
    "log"
    "math/rand"
    "net/http"
    "net/http/httptest"
    "sort"
    "sync"
    "time"

    "golang.org/x/sync/errgroup"
    "golang.org/x/time/rate"
)

// main wires a producer (GET) and a consumer (POST) together through a
// channel, limiting each side with its own rate.Limiter (requests/second)
// and errgroup limit (concurrent goroutines), with per-request retries.
func main() {
	s := &server{}
	ts := httptest.NewServer(s)
	defer ts.Close()

	// The get action is a producer, and the post action is a consumer.
	// Choose a chan size according to the producing/consuming speed.
	ch := make(chan string, 10)

	ctx := context.Background()

	maxGet, maxPost := 10, 20
	retryCount := 3

	g1 := new(errgroup.Group)
	// g1 is used to limit the number of goroutines for the get actions.
	// choose the number according to the latency of the get request.
	g1.SetLimit(maxGet * 2)
	lim1 := rate.NewLimiter(rate.Limit(maxGet), 1)

	go func() {
		for i := 0; i < 50; i++ {
			i := i // capture the loop variable (pre-Go 1.22 semantics)
			g1.Go(func() error {
				for j := 0; j < retryCount; j++ {
					id, err := get(ctx, lim1, fmt.Sprintf("%s/%d", ts.URL, i))
					if err != nil {
						if errors.Is(err, context.Canceled) {
							return err
						}
						log.Printf("get %d: %v", i, err)
						continue // retry
					}

					ch <- id
					return nil
				}
				log.Printf("get request %d failed after %d retries", i, retryCount)
				return nil
			})
		}
		g1.Wait()
		// Closing ch tells the consumer range-loop below that
		// production is done.
		close(ch)
	}()

	g2 := new(errgroup.Group)
	// g2 is used to limit the number of goroutines for the post actions.
	// choose the number according to the latency of the post request.
	g2.SetLimit(maxPost * 3)
	lim2 := rate.NewLimiter(rate.Limit(maxPost), 1)

	for id := range ch {
		id := id // capture the loop variable (pre-Go 1.22 semantics)
		g2.Go(func() error {
			for j := 0; j < retryCount; j++ {
				err := post(ctx, lim2, fmt.Sprintf("%s/%s", ts.URL, id))
				if err != nil {
					if errors.Is(err, context.Canceled) {
						return err
					}
					log.Printf("post: %v", err)
					continue // retry
				}

				return nil
			}
			log.Printf("post request %s failed after %d retries", id, retryCount)
			return nil
		})
	}
	g2.Wait()

	s.printStats()
}

// get waits for its turn on the rate limiter, fetches url, and returns
// the response body as a string. Any status other than 200 is an error.
func get(ctx context.Context, lim *rate.Limiter, url string) (string, error) {
	if err := lim.Wait(ctx); err != nil {
		return "", err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return "", err
	}

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		return "", fmt.Errorf("unexpected status code: %d", res.StatusCode)
	}

	data, err := io.ReadAll(res.Body)
	if err != nil {
		return "", err
	}

	return string(data), nil
}

// post waits for its turn on the rate limiter, then POSTs to url with an
// empty body. Any status other than 200 is reported as an error.
func post(ctx context.Context, lim *rate.Limiter, url string) error {
	if err := lim.Wait(ctx); err != nil {
		return err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
	if err != nil {
		return err
	}

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected status code: %d", res.StatusCode)
	}

	return nil
}

// server is a test HTTP handler that records the arrival time of every
// request so the achieved request rates can be inspected afterwards.
type server struct {
	gMu  sync.Mutex // guards gets
	gets []int64    // Unix-millisecond timestamps of GET requests

	pMu   sync.Mutex // guards posts
	posts []int64    // Unix-millisecond timestamps of non-GET requests
}

// ServeHTTP records the arrival timestamp of each request, sleeps a
// random amount to simulate latency, and occasionally answers 500 to
// simulate server errors. GET requests echo back the request path
// (minus the leading slash) as the response body.
func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	log.Printf("%s: %s", r.Method, r.URL.Path)

	// collect request stats.
	switch r.Method {
	case http.MethodGet:
		s.gMu.Lock()
		s.gets = append(s.gets, time.Now().UnixMilli())
		s.gMu.Unlock()
	default:
		s.pMu.Lock()
		s.posts = append(s.posts, time.Now().UnixMilli())
		s.pMu.Unlock()
	}

	delay := rand.Intn(1000)
	// simulate latency.
	time.Sleep(time.Duration(delay) * time.Millisecond)

	// simulate errors (roughly 10% of requests).
	if delay%10 == 0 {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	if r.Method == http.MethodGet {
		fmt.Fprintf(w, "%s", r.URL.Path[1:])
	}
}

// printStats logs the total number of GET and POST requests observed
// and their per-second distribution.
func (s *server) printStats() {
	log.Printf("GETS (total: %d):\n", len(s.gets))
	printStats(s.gets)

	log.Printf("POSTS (total: %d):\n", len(s.posts))
	printStats(s.posts)
}

// printStats prints how many timestamps fall into each successive
// one-second bucket, starting at the earliest timestamp. ts holds
// Unix-millisecond timestamps and is sorted in place. An empty or nil
// slice prints nothing (the original panicked on ts[0]).
func printStats(ts []int64) {
	if len(ts) == 0 {
		return
	}

	sort.Slice(ts, func(i, j int) bool {
		return ts[i] < ts[j]
	})

	count := 0
	to := ts[0] + 1000
	for i := 0; i < len(ts); i++ {
		if ts[i] < to {
			count++
		} else {
			fmt.Printf("  %d: %d\n", to, count)
			i-- // push back the current item so it is re-checked against the next bucket
			count = 0
			to += 1000
		}
	}
	if count > 0 {
		fmt.Printf("  %d: %d\n", to, count)
	}
}

输出如下所示:

...
2023/03/25 14:59:12 GETS (total: 56):
  1679727546667: 10
  1679727547667: 10
  1679727548667: 10
  1679727549667: 10
  1679727550667: 10
  1679727551667: 5
  1679727552667: 1
2023/03/25 14:59:12 POSTS (total: 55):
  1679727546749: 8
  1679727547749: 8
  1679727548749: 12
  1679727549749: 8
  1679727550749: 10
  1679727551749: 6
  1679727552749: 3
© www.soinside.com 2019 - 2024. All rights reserved.