Commit 4477f62c authored by Mitar's avatar Mitar
Browse files

Improve RetryableResponse use from multiple gothreads.

Especially that you can read from one gothread and close it from another.
parent 23f6e54e
Pipeline #441971220 passed with stages
in 2 minutes and 17 seconds
......@@ -6,6 +6,8 @@ import (
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"github.com/hashicorp/go-retryablehttp"
"gitlab.com/tozd/go/errors"
......@@ -24,6 +26,7 @@ type RetryableResponse struct {
req *retryablehttp.Request
count int64
size int64
lock sync.Mutex
*http.Response
}
......@@ -32,22 +35,32 @@ type RetryableResponse struct {
// Use this to read the response
// body and not RetryableResponse.Response.Body.Read.
func (d *RetryableResponse) Read(p []byte) (int, error) {
n, err := d.Response.Body.Read(p)
d.count += int64(n)
if d.count == d.size {
d.lock.Lock()
resp := d.Response
d.lock.Unlock()
if resp == nil {
return 0, errors.New("response already closed")
}
n, err := resp.Body.Read(p)
count := atomic.AddInt64(&d.count, int64(n))
size := d.Size()
if count == size {
// We read everything, just return as-is.
return n, err
} else if d.count > d.size {
} else if count > size {
if err != nil {
return n, errors.Wrapf(err, "read beyond the expected end of the response body (%d vs. %d)", d.count, d.size)
return n, errors.Wrapf(err, "read beyond the expected end of the response body (%d vs. %d)", count, size)
}
return n, errors.Errorf("read beyond the expected end of the response body (%d vs. %d)", d.count, d.size)
return n, errors.Errorf("read beyond the expected end of the response body (%d vs. %d)", count, size)
} else if contextErr := d.req.Context().Err(); contextErr != nil {
// Do not retry on context.Canceled or context.DeadlineExceeded.
return n, contextErr
} else if err != nil {
// We have not read everything, but we got an error. We retry.
errStart := d.start(d.count)
errStart := d.start()
if errStart != nil {
return n, errStart
}
......@@ -65,30 +78,36 @@ func (d *RetryableResponse) Read(p []byte) (int, error) {
//
// It returns the number of bytes read until now.
func (d *RetryableResponse) Count() int64 {
return d.size
return atomic.LoadInt64(&d.count)
}
// Size returns the expected number of bytes to read.
func (d *RetryableResponse) Size() int64 {
return d.size
return atomic.LoadInt64(&d.size)
}
// Close implements io.Closer interface for RetryableResponse.
//
// It closes the underlying response body.
func (d *RetryableResponse) Close() error {
if d.Response != nil {
err := errors.WithStack(d.Response.Body.Close())
d.Response = nil
return err
d.lock.Lock()
resp := d.Response
d.Response = nil
d.lock.Unlock()
if resp != nil {
return errors.WithStack(resp.Body.Close())
}
return nil
}
func (d *RetryableResponse) start(from int64) errors.E {
func (d *RetryableResponse) start() errors.E {
d.Close()
if from > 0 {
d.req.Header.Set("Range", fmt.Sprintf("bytes=%d-", from))
count := d.Count()
if count > 0 {
d.req.Header.Set("Range", fmt.Sprintf("bytes=%d-", count))
} else {
d.req.Header.Del("Range")
}
......@@ -96,11 +115,10 @@ func (d *RetryableResponse) start(from int64) errors.E {
if err != nil {
return errors.WithStack(err)
}
if (from > 0 && resp.StatusCode != http.StatusPartialContent) || (from <= 0 && resp.StatusCode != http.StatusOK) {
if (count > 0 && resp.StatusCode != http.StatusPartialContent) || (count <= 0 && resp.StatusCode != http.StatusOK) {
body, _ := io.ReadAll(resp.Body)
return errors.Errorf("bad response status (%s): %s", resp.Status, strings.TrimSpace(string(body)))
}
d.Response = resp
lengthStr := resp.Header.Get("Content-Length")
if lengthStr == "" {
return errors.Errorf("missing Content-Length header in response")
......@@ -109,13 +127,20 @@ func (d *RetryableResponse) start(from int64) errors.E {
if err != nil {
return errors.WithStack(err)
}
if from > 0 {
if d.count+length != d.size {
return errors.Errorf("content after retry has different length (%d) than before (%d)", d.count+length, d.size)
size := d.Size()
if count > 0 {
if count+length != size {
return errors.Errorf("content after retry has different length (%d) than before (%d)", count+length, size)
}
} else {
d.size = length
atomic.StoreInt64(&d.size, length)
}
d.lock.Lock()
d.Response = resp
d.lock.Unlock()
return nil
}
......@@ -126,9 +151,10 @@ func NewRetryableResponse(client *retryablehttp.Client, req *retryablehttp.Reque
req: req,
count: 0,
size: 0,
lock: sync.Mutex{},
Response: nil,
}
err := r.start(0)
err := r.start()
if err != nil {
return nil, err
}
......
......@@ -35,6 +35,7 @@ func TestRetryableResponseSimple(t *testing.T) {
res, err := x.NewRetryableResponse(client, req)
require.NoError(t, err)
require.NotNil(t, res)
defer res.Close()
assert.Equal(t, "14", res.Header.Get("Content-Length"))
assert.Equal(t, int64(14), res.Size())
......@@ -80,6 +81,7 @@ func TestRetryableResponseRetry(t *testing.T) {
res, err := x.NewRetryableResponse(client, req)
require.NoError(t, err)
require.NotNil(t, res)
defer res.Close()
assert.Equal(t, "14", res.Header.Get("Content-Length"))
assert.Equal(t, int64(14), res.Size())
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment