Commit 0e47b69e authored by Mitar's avatar Mitar
Browse files

Add more features.

parent 349f048c
Pipeline #441204560 passed with stages
in 2 minutes and 14 seconds
......@@ -9,6 +9,8 @@ require (
gitlab.com/tozd/go/errors v0.3.0
)
require github.com/hashicorp/go-cleanhttp v0.5.1 // indirect
require (
github.com/Microsoft/go-winio v0.4.16 // indirect
github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7 // indirect
......@@ -17,6 +19,7 @@ require (
github.com/emirpasic/gods v1.12.0 // indirect
github.com/go-git/gcfg v1.5.0 // indirect
github.com/go-git/go-billy/v5 v5.3.1 // indirect
github.com/hashicorp/go-retryablehttp v0.7.1-0.20211018174820-ff6d014e72d9
github.com/imdario/mergo v0.3.12 // indirect
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
github.com/kevinburke/ssh_config v0.0.0-20201106050909-4977a11b4351 // indirect
......
package x
import (
"bytes"
"encoding/json"
"gitlab.com/tozd/go/errors"
)
// UnmarshalWithoutUnknownFields is a standard JSON unmarshal, just
// that it returns an error if there is any unknown field present in JSON.
func UnmarshalWithoutUnknownFields(data []byte, v interface{}) errors.E {
decoder := json.NewDecoder(bytes.NewReader(data))
decoder.DisallowUnknownFields()
err := decoder.Decode(v)
if err != nil {
return errors.WithStack(err)
}
return nil
}
package x_test
import (
"testing"
"github.com/stretchr/testify/assert"
"gitlab.com/tozd/go/x"
)
func TestUnmarshalWithoutUnknownFields(t *testing.T) {
type Test struct {
Field string
}
var v Test
err := x.UnmarshalWithoutUnknownFields([]byte(`{}`), &v)
assert.NoError(t, err)
err = x.UnmarshalWithoutUnknownFields([]byte(`{"field": "abc"}`), &v)
assert.NoError(t, err)
err = x.UnmarshalWithoutUnknownFields([]byte(`{"field2": "abc"}`), &v)
assert.Error(t, err)
}
package x
import (
"context"
"io"
"sync/atomic"
"time"
)
// CountingReader is an io.Reader proxy which counts the number of bytes
// it read and passed on.
type CountingReader struct {
Reader io.Reader
count int64
}
// NewCountingReader returns a new CountingReader which reads
// from the reader and counts the bytes.
func NewCountingReader(reader io.Reader) *CountingReader {
return &CountingReader{
Reader: reader,
count: 0,
}
}
// Read implements io.Reader interface for CountingReader.
func (c *CountingReader) Read(p []byte) (int, error) {
n, err := c.Reader.Read(p)
atomic.AddInt64(&c.count, int64(n))
return n, err
}
// Count implements counter interface for CountingReader.
//
// It returns the number of bytes read until now.
func (c *CountingReader) Count() int64 {
return atomic.LoadInt64(&c.count)
}
type counter interface {
Count() int64
}
// Progress describes current progress as reported by the counter.
type Progress struct {
Count int64
Size int64
Started time.Time
Current time.Time
Elapsed time.Duration
remaining time.Duration
estimated time.Time
}
func (p Progress) Percent() float64 {
return float64(p.Count) / float64(p.Size) * 100.0 //nolint:gomnd
}
func (p Progress) Remaining() time.Duration {
return p.remaining
}
func (p Progress) Estimated() time.Time {
return p.estimated
}
type Ticker struct {
C <-chan Progress
stop func()
}
// Stop stops the ticker and frees resources.
func (t *Ticker) Stop() {
t.stop()
}
// NewTicker creates a new Ticker which at regular interval reports the
// progress as reported by the counter c.
func NewTicker(ctx context.Context, c counter, size int64, interval time.Duration) *Ticker {
ctx, cancel := context.WithCancel(ctx)
started := time.Now()
output := make(chan Progress)
ticker := time.NewTicker(interval)
go func() {
defer cancel()
defer close(output)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case now := <-ticker.C:
count := c.Count()
elapsed := now.Sub(started)
ratio := float64(count) / float64(size)
total := time.Duration(float64(elapsed) / ratio)
estimated := started.Add(total)
progress := Progress{
Count: count,
Size: size,
Started: started,
Current: now,
Elapsed: elapsed,
remaining: estimated.Sub(now),
estimated: estimated,
}
if ctx.Err() != nil {
return
}
output <- progress
}
}
}()
return &Ticker{
C: output,
stop: cancel,
}
}
package x_test
import (
"context"
"io"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/tozd/go/x"
)
const (
tickerInterval = 50 * time.Millisecond
)
func TestTicker(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
r, w := io.Pipe()
defer r.Close()
defer w.Close()
countingReader := x.NewCountingReader(r)
go func() {
_, _ = io.ReadAll(countingReader)
}()
ticker := x.NewTicker(ctx, countingReader, 10, tickerInterval)
require.NotNil(t, ticker)
defer ticker.Stop()
l := sync.RWMutex{}
progress := []x.Progress{}
go func() {
for p := range ticker.C {
func() {
l.Lock()
defer l.Unlock()
progress = append(progress, p)
}()
}
}()
time.Sleep(2 * tickerInterval)
var p x.Progress
func() {
l.Lock()
defer l.Unlock()
require.NotEmpty(t, progress)
p = progress[len(progress)-1]
}()
assert.Equal(t, int64(10), p.Size)
assert.Equal(t, int64(0), p.Count)
assert.Equal(t, 0.0, p.Percent())
n, err := w.Write([]byte("abcd"))
assert.Equal(t, 4, n)
assert.NoError(t, err)
time.Sleep(2 * tickerInterval)
func() {
l.Lock()
defer l.Unlock()
require.NotEmpty(t, progress)
p = progress[len(progress)-1]
}()
assert.Equal(t, int64(10), p.Size)
assert.Equal(t, int64(4), p.Count)
assert.Equal(t, 40.0, p.Percent())
cancel()
// We give time for cancel to propagate.
time.Sleep(2 * tickerInterval)
var progressLen int
func() {
l.Lock()
defer l.Unlock()
progressLen = len(progress)
}()
// After this there should be no new progress added.
time.Sleep(2 * tickerInterval)
func() {
l.Lock()
defer l.Unlock()
assert.Equal(t, progressLen, len(progress))
}()
// Channel should be closed.
select {
case _, ok := <-ticker.C:
if ok {
require.Fail(t, "progress where there should be none")
}
default:
}
}
package x
import (
"fmt"
"io"
"net/http"
"strconv"
"strings"
"github.com/hashicorp/go-retryablehttp"
"gitlab.com/tozd/go/errors"
)
// RetryableResponse reads the response body until it is completely read.
//
// If reading fails before full contents have been read
// (based on the Content-Length header), it transparently retries the request
// using Range request header and continues reading the new response body.
//
// It embeds the current response (so you can access response headers, etc.)
// but the current response can change when the request is retried.
type RetryableResponse struct {
client *retryablehttp.Client
req *retryablehttp.Request
count int64
size int64
*http.Response
}
// Read implements io.Reader for RetryableResponse.
//
// Use this to read the response
// body and not RetryableResponse.Response.Body.Read.
func (d *RetryableResponse) Read(p []byte) (int, error) {
n, err := d.Response.Body.Read(p)
d.count += int64(n)
if d.count == d.size {
// We read everything, just return as-is.
return n, err
} else if d.count > d.size {
if err != nil {
return n, errors.Wrapf(err, "read beyond the expected end of the response body (%d vs. %d)", d.count, d.size)
}
return n, errors.Errorf("read beyond the expected end of the response body (%d vs. %d)", d.count, d.size)
} else if contextErr := d.req.Context().Err(); contextErr != nil {
// Do not retry on context.Canceled or context.DeadlineExceeded.
return n, contextErr
} else if err != nil {
// We have not read everything, but we got an error. We retry.
errStart := d.start(d.count)
if errStart != nil {
return n, errStart
}
if n > 0 {
return n, nil
}
return d.Read(p)
} else {
// Something else, just return as-is.
return n, err
}
}
// Count implements counter interface for RetryableResponse.
//
// It returns the number of bytes read until now.
func (d *RetryableResponse) Count() int64 {
return d.size
}
// Size returns the expected number of bytes to read.
func (d *RetryableResponse) Size() int64 {
return d.size
}
// Close implements io.Closer interface for RetryableResponse.
//
// It closes the underlying response body.
func (d *RetryableResponse) Close() error {
if d.Response != nil {
err := errors.WithStack(d.Response.Body.Close())
d.Response = nil
return err
}
return nil
}
func (d *RetryableResponse) start(from int64) errors.E {
d.Close()
if from > 0 {
d.req.Header.Set("Range", fmt.Sprintf("bytes=%d-", from))
} else {
d.req.Header.Del("Range")
}
resp, err := d.client.Do(d.req) //nolint:bodyclose
if err != nil {
return errors.WithStack(err)
}
if (from > 0 && resp.StatusCode != http.StatusPartialContent) || (from <= 0 && resp.StatusCode != http.StatusOK) {
body, _ := io.ReadAll(resp.Body)
return errors.Errorf("bad response status (%s): %s", resp.Status, strings.TrimSpace(string(body)))
}
d.Response = resp
lengthStr := resp.Header.Get("Content-Length")
if lengthStr == "" {
return errors.Errorf("missing Content-Length header in response")
}
length, err := strconv.ParseInt(lengthStr, 10, 64) //nolint:gomnd
if err != nil {
return errors.WithStack(err)
}
if from > 0 {
if d.count+length != d.size {
return errors.Errorf("content after retry has different length (%d) than before (%d)", d.count+length, d.size)
}
} else {
d.size = length
}
return nil
}
// NewRetryableResponse returns a RetryableResponse given the client and request to do (and potentially retry).
func NewRetryableResponse(client *retryablehttp.Client, req *retryablehttp.Request) (*RetryableResponse, errors.E) {
r := &RetryableResponse{
client: client,
req: req,
count: 0,
size: 0,
Response: nil,
}
err := r.start(0)
if err != nil {
return nil, err
}
return r, nil
}
package x_test
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"github.com/hashicorp/go-retryablehttp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/tozd/go/x"
)
const responseBody = "Hello, client\n"
func TestRetryableResponseSimple(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, responseBody)
}))
defer ts.Close()
ctx := context.Background()
client := retryablehttp.NewClient()
req, err := retryablehttp.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil)
require.NoError(t, err)
require.NotNil(t, req)
res, err := x.NewRetryableResponse(client, req)
require.NoError(t, err)
require.NotNil(t, res)
assert.Equal(t, "14", res.Header.Get("Content-Length"))
assert.Equal(t, int64(14), res.Size())
response, err := io.ReadAll(res)
require.NoError(t, err)
assert.Equal(t, responseBody, string(response))
}
func TestRetryableResponseRetry(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if reqRange := r.Header.Get("Range"); reqRange != "" {
require.True(t, strings.HasPrefix(reqRange, "bytes="))
reqRange = strings.TrimPrefix(reqRange, "bytes=")
rs := strings.Split(reqRange, "-")
require.Equal(t, 2, len(rs))
end := rs[1]
require.Equal(t, "", end)
start, err := strconv.Atoi(rs[0])
require.NoError(t, err)
require.Equal(t, 6, start)
rest := responseBody[start:]
w.Header().Set("Content-Length", strconv.Itoa(len(rest)))
w.WriteHeader(http.StatusPartialContent)
// Send the rest.
fmt.Fprint(w, rest)
} else {
w.Header().Set("Content-Length", strconv.Itoa(len(responseBody)))
w.WriteHeader(http.StatusOK)
// Send only the first 6 bytes.
fmt.Fprint(w, responseBody[0:6])
}
}))
defer ts.Close()
ctx := context.Background()
client := retryablehttp.NewClient()
req, err := retryablehttp.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil)
require.NoError(t, err)
require.NotNil(t, req)
res, err := x.NewRetryableResponse(client, req)
require.NoError(t, err)
require.NotNil(t, res)
assert.Equal(t, "14", res.Header.Get("Content-Length"))
assert.Equal(t, int64(14), res.Size())
data, err := io.ReadAll(res)
require.NoError(t, err)
assert.Equal(t, responseBody, string(data))
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment