Verified Commit f829cd99 authored by Grzegorz Bizon's avatar Grzegorz Bizon 💡 Committed by Nick Thomas

Limit memory footprint of artifacts metadata processing

parent f04a5c99
---
title: Limit memory footprint of a command that generates ZIP artifacts metadata
merge_request:
author:
type: security
......@@ -3,6 +3,7 @@ package main
import (
"archive/zip"
"context"
"errors"
"flag"
"fmt"
"io"
......@@ -42,7 +43,7 @@ func main() {
fileName, err := zipartifacts.DecodeFileEntry(encodedFileName)
if err != nil {
fatalError(fmt.Errorf("decode entry %q: %v", encodedFileName, err))
fatalError(fmt.Errorf("decode entry %q", encodedFileName), err)
}
ctx, cancel := context.WithCancel(context.Background())
......@@ -50,30 +51,26 @@ func main() {
archive, err := zipartifacts.OpenArchive(ctx, archivePath)
if err != nil {
oaError := fmt.Errorf("OpenArchive: %v", err)
if err == zipartifacts.ErrArchiveNotFound {
notFoundError(oaError)
}
fatalError(oaError)
fatalError(errors.New("open archive"), err)
}
file := findFileInZip(fileName, archive)
if file == nil {
notFoundError(fmt.Errorf("find %q in %q: not found", fileName, scrubbedArchivePath))
fatalError(fmt.Errorf("find %q in %q: not found", fileName, scrubbedArchivePath), zipartifacts.ErrorCode[zipartifacts.CodeEntryNotFound])
}
// Start decompressing the file
reader, err := file.Open()
if err != nil {
fatalError(fmt.Errorf("open %q in %q: %v", fileName, scrubbedArchivePath, err))
fatalError(fmt.Errorf("open %q in %q", fileName, scrubbedArchivePath), err)
}
defer reader.Close()
if _, err := fmt.Printf("%d\n", file.UncompressedSize64); err != nil {
fatalError(fmt.Errorf("write file size: %v", err))
fatalError(fmt.Errorf("write file size invalid"), err)
}
if _, err := io.Copy(os.Stdout, reader); err != nil {
fatalError(fmt.Errorf("write %q from %q to stdout: %v", fileName, scrubbedArchivePath, err))
fatalError(fmt.Errorf("write %q from %q to stdout", fileName, scrubbedArchivePath), err)
}
}
......@@ -86,16 +83,14 @@ func findFileInZip(fileName string, archive *zip.Reader) *zip.File {
return nil
}
func printError(err error) {
fmt.Fprintf(os.Stderr, "%s: %v", progName, err)
}
func fatalError(contextErr error, statusErr error) {
code := zipartifacts.ExitCodeByError(statusErr)
func fatalError(err error) {
printError(err)
os.Exit(1)
}
fmt.Fprintf(os.Stderr, "%s error: %v - %v, code: %d\n", progName, statusErr, contextErr, code)
func notFoundError(err error) {
printError(err)
os.Exit(zipartifacts.StatusEntryNotFound)
if code > 0 {
os.Exit(code)
} else {
os.Exit(1)
}
}
package limit
import (
"errors"
"io"
"sync/atomic"
)
var ErrLimitExceeded = errors.New("reader limit exceeded")
const megabyte = 1 << 20
// LimitedReaderAt supports running a callback in case of reaching a read limit
// (bytes), and allows using a smaller limit than a defined offset for a read.
type LimitedReaderAt struct {
read int64
limit int64
parent io.ReaderAt
limitFunc func(int64)
}
func (r *LimitedReaderAt) ReadAt(p []byte, off int64) (int, error) {
if max := r.limit - r.read; int64(len(p)) > max {
p = p[0:max]
}
n, err := r.parent.ReadAt(p, off)
atomic.AddInt64(&r.read, int64(n))
if r.read >= r.limit {
r.limitFunc(r.read)
return n, ErrLimitExceeded
}
return n, err
}
func NewLimitedReaderAt(reader io.ReaderAt, limit int64, limitFunc func(int64)) io.ReaderAt {
return &LimitedReaderAt{parent: reader, limit: limit, limitFunc: limitFunc}
}
// SizeToLimit tries to determine an appropriate limit in bytes for an archive
// of a given size. Archives up to 1 gigabyte get a fixed 100 megabyte read
// limit; for anything larger the limit is 10% of the archive size.
func SizeToLimit(size int64) int64 {
	const gigabyte = 1024 << 20

	if size <= gigabyte {
		return 100 << 20 // 100 megabytes
	}

	return size / 10
}
package limit
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// TestReadAt covers the limiter around the limit boundary: reads below the
// limit, reads that cross it, reads whose offset exceeds it, and reads that
// begin exactly at it.
func TestReadAt(t *testing.T) {
	t.Run("when limit has not been reached", func(t *testing.T) {
		dst := make([]byte, 11)
		lr := NewLimitedReaderAt(strings.NewReader("some string to read"), 32, func(n int64) {
			require.Zero(t, n)
		})

		n, err := lr.ReadAt(dst, 0)

		require.NoError(t, err)
		require.Equal(t, 11, n)
		require.Equal(t, "some string", string(dst))
	})

	t.Run("when read limit is exceeded", func(t *testing.T) {
		dst := make([]byte, 11)
		lr := NewLimitedReaderAt(strings.NewReader("some string to read"), 9, func(n int64) {
			require.Equal(t, 9, int(n))
		})

		// Only 9 of the requested 11 bytes are read; the tail stays zeroed.
		n, err := lr.ReadAt(dst, 0)

		require.Error(t, err)
		require.Equal(t, 9, n)
		require.Equal(t, "some stri\x00\x00", string(dst))
	})

	t.Run("when offset is higher than a limit", func(t *testing.T) {
		dst := make([]byte, 4)
		lr := NewLimitedReaderAt(strings.NewReader("some string to read"), 5, func(n int64) {
			require.Zero(t, n)
		})

		// The limit caps bytes read, not the offset, so this still succeeds.
		n, err := lr.ReadAt(dst, 15)

		require.NoError(t, err)
		require.Equal(t, 4, n)
		require.Equal(t, "read", string(dst))
	})

	t.Run("when a read starts at the limit", func(t *testing.T) {
		dst := make([]byte, 11)
		lr := NewLimitedReaderAt(strings.NewReader("some string to read"), 10, func(n int64) {
			require.Equal(t, 10, int(n))
		})

		// First read consumes the entire budget; the second returns no data.
		lr.ReadAt(dst, 0)
		n, err := lr.ReadAt(dst, 0)

		require.EqualError(t, err, ErrLimitExceeded.Error())
		require.Equal(t, 0, n)
		require.Equal(t, "some strin\x00", string(dst))
	})
}
// TestSizeToLimit verifies the size-to-limit mapping at and around the
// 1 gigabyte threshold.
func TestSizeToLimit(t *testing.T) {
	cases := []struct {
		name string
		size int64
		want int64
	}{
		{name: "1b to 100mb", size: 1, want: 104857600},
		{name: "100b to 100mb", size: 100, want: 104857600},
		{name: "100mb to 100mb", size: 104857600, want: 104857600},
		{name: "1gb to 100mb", size: 1073741824, want: 104857600},
		{name: "10gb to 1gb", size: 10737418240, want: 1073741824},
		{name: "50gb to 5gb", size: 53687091200, want: 5368709120},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, SizeToLimit(tc.size))
		})
	}
}
......@@ -4,8 +4,10 @@ import (
"context"
"flag"
"fmt"
"io"
"os"
"gitlab.com/gitlab-org/gitlab-workhorse/cmd/gitlab-zip-metadata/limit"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts"
)
......@@ -29,10 +31,20 @@ func main() {
os.Exit(1)
}
readerFunc := func(reader io.ReaderAt, size int64) io.ReaderAt {
readLimit := limit.SizeToLimit(size)
return limit.NewLimitedReaderAt(reader, readLimit, func(read int64) {
fmt.Fprintf(os.Stderr, "%s: zip archive limit exceeded after reading %d bytes\n", progName, read)
fatalError(zipartifacts.ErrorCode[zipartifacts.CodeLimitsReached])
})
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
archive, err := zipartifacts.OpenArchive(ctx, os.Args[1])
archive, err := zipartifacts.OpenArchiveWithReaderFunc(ctx, os.Args[1], readerFunc)
if err != nil {
fatalError(err)
}
......@@ -43,9 +55,13 @@ func main() {
}
func fatalError(err error) {
fmt.Fprintf(os.Stderr, "%s: %v\n", progName, err)
if err == zipartifacts.ErrNotAZip {
os.Exit(zipartifacts.StatusNotZip)
code := zipartifacts.ExitCodeByError(err)
fmt.Fprintf(os.Stderr, "%s error: %v, code: %d\n", progName, err, code)
if code > 0 {
os.Exit(code)
} else {
os.Exit(1)
}
os.Exit(1)
}
......@@ -10,6 +10,8 @@ import (
"os/exec"
"syscall"
"github.com/prometheus/client_golang/prometheus"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
......@@ -17,6 +19,16 @@ import (
"gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts"
)
// zipSubcommandsErrorsCounter counts errors reported by the ZIP-processing
// subcommands, labeled with the error category (see
// zipartifacts.ErrorLabelByCode).
var zipSubcommandsErrorsCounter = prometheus.NewCounterVec(
	prometheus.CounterOpts{
		Name: "gitlab_workhorse_zip_subcommand_errors_total",
		// Fixed typo in user-facing metric help text: "comming" -> "coming".
		Help: "Errors coming from subcommands used for processing ZIP archives",
	}, []string{"error"})
// init registers the ZIP subcommand error counter with the default Prometheus
// registry so it is exposed on the metrics endpoint.
func init() {
	prometheus.MustRegister(zipSubcommandsErrorsCounter)
}
type artifactsUploadProcessor struct {
opts *filestore.SaveFileOpts
......@@ -63,10 +75,21 @@ func (a *artifactsUploadProcessor) generateMetadataFromZip(ctx context.Context,
}()
if err := zipMd.Wait(); err != nil {
if st, ok := helper.ExitStatus(err); ok && st == zipartifacts.StatusNotZip {
st, ok := helper.ExitStatus(err)
if !ok {
return nil, err
}
zipSubcommandsErrorsCounter.WithLabelValues(zipartifacts.ErrorLabelByCode(st)).Inc()
if st == zipartifacts.CodeNotZip {
return nil, nil
}
return nil, err
if st == zipartifacts.CodeLimitsReached {
return nil, zipartifacts.ErrBadMetadata
}
}
metaWriter.Close()
......@@ -93,7 +116,7 @@ func (a *artifactsUploadProcessor) ProcessFile(ctx context.Context, formName str
// TODO: can we rely on disk for shipping metadata? Not if we split workhorse and rails in 2 different PODs
metadata, err := a.generateMetadataFromZip(ctx, file)
if err != nil {
return fmt.Errorf("generateMetadataFromZip: %v", err)
return err
}
if metadata != nil {
......@@ -109,6 +132,7 @@ func (a *artifactsUploadProcessor) ProcessFile(ctx context.Context, formName str
a.Track("metadata", metadata.LocalPath)
}
}
return nil
}
......
......@@ -112,7 +112,9 @@ func waitCatFile(cmd *exec.Cmd) error {
return nil
}
if st, ok := helper.ExitStatus(err); ok && st == zipartifacts.StatusEntryNotFound {
st, ok := helper.ExitStatus(err)
if ok && (st == zipartifacts.CodeArchiveNotFound || st == zipartifacts.CodeEntryNotFound) {
return os.ErrNotExist
}
return fmt.Errorf("wait for %v to finish: %v", cmd.Args, err)
......
......@@ -12,6 +12,7 @@ import (
"gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/upload/exif"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts"
)
// These methods are allowed to have thread-unsafe implementations.
......@@ -43,6 +44,8 @@ func HandleFileUploads(w http.ResponseWriter, r *http.Request, h http.Handler, p
h.ServeHTTP(w, r)
case filestore.ErrEntityTooLarge:
helper.RequestEntityTooLarge(w, r, err)
case zipartifacts.ErrBadMetadata:
helper.RequestEntityTooLarge(w, r, err)
case exif.ErrRemovingExif:
helper.CaptureAndFail(w, r, err, "Failed to process image", http.StatusUnprocessableEntity)
default:
......
package zipartifacts
// These are exit codes used by subprocesses in cmd/gitlab-zip-xxx
const (
	StatusNotZip        = 10 + iota // the input file is not a valid ZIP archive
	StatusEntryNotFound             // the requested entry is missing from the archive
)
package zipartifacts
import (
"errors"
)
// These are exit codes used by subprocesses in cmd/gitlab-zip-xxx. We also use
// them to map errors and error messages that we use as label in Prometheus.
const (
	CodeNotZip          = 10 + iota // the file is not a valid ZIP archive
	CodeEntryNotFound               // a requested entry is missing from the archive
	CodeArchiveNotFound             // the archive itself could not be found
	CodeLimitsReached               // processing limits were exceeded
	CodeUnknownError                // any error not covered by the codes above
)
var (
	// ErrorCode maps an exit code to a sentinel error describing it. Errors
	// are matched by identity in ExitCodeByError, so callers should use the
	// map entries rather than constructing new errors with the same text.
	ErrorCode = map[int]error{
		CodeNotZip:          errors.New("zip archive format invalid"),
		CodeEntryNotFound:   errors.New("zip entry not found"),
		CodeArchiveNotFound: errors.New("zip archive not found"),
		CodeLimitsReached:   errors.New("zip processing limits reached"),
		CodeUnknownError:    errors.New("zip processing unknown error"),
	}

	// ErrorLabel maps an exit code to the Prometheus label value used when
	// counting subcommand errors.
	ErrorLabel = map[int]string{
		CodeNotZip:          "archive_invalid",
		CodeEntryNotFound:   "entry_not_found",
		CodeArchiveNotFound: "archive_not_found",
		CodeLimitsReached:   "limits_reached",
		CodeUnknownError:    "unknown_error",
	}

	// ErrBadMetadata signals that ZIP artifacts metadata could not be
	// generated, e.g. because processing limits were reached.
	ErrBadMetadata = errors.New("zip artifacts metadata invalid")
)
// ExitCodeByError finds an os.Exit code for a corresponding error, returning
// CodeUnknownError in case it can not be found. Errors are matched by
// identity, so only the sentinel errors from the ErrorCode map are recognized.
func ExitCodeByError(err error) int {
	for c, e := range ErrorCode {
		if err == e {
			return c
		}
	}

	return CodeUnknownError
}
// ErrorLabelByCode returns a Prometheus counter label associated with an exit
// code, falling back to the label for CodeUnknownError when the code is not
// recognized.
func ErrorLabelByCode(code int) string {
	if label, found := ErrorLabel[code]; found {
		return label
	}

	return ErrorLabel[CodeUnknownError]
}
package zipartifacts
import (
"errors"
"testing"
"github.com/stretchr/testify/require"
)
// TestExitCodeByError checks mapping of both recognized sentinel errors and
// arbitrary unknown errors to exit codes.
func TestExitCodeByError(t *testing.T) {
	cases := []struct {
		name string
		err  error
		want int
	}{
		{name: "when error has been recognized", err: ErrorCode[CodeLimitsReached], want: CodeLimitsReached},
		{name: "when error is an unknown one", err: errors.New("unknown error"), want: CodeUnknownError},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			got := ExitCodeByError(tc.err)

			require.Equal(t, got, tc.want)
			// All exit codes are above 10 to avoid clashing with shell codes.
			require.Greater(t, got, 10)
		})
	}
}
// TestErrorLabels ensures every exit code with a sentinel error also has a
// Prometheus label defined.
func TestErrorLabels(t *testing.T) {
	for code := range ErrorCode {
		require.Contains(t, ErrorLabel, code)
	}
}
......@@ -100,5 +100,5 @@ func TestErrNotAZip(t *testing.T) {
defer cancel()
_, err = zipartifacts.OpenArchive(ctx, f.Name())
assert.Equal(t, zipartifacts.ErrNotAZip, err, "OpenArchive requires a zip file")
assert.Equal(t, zipartifacts.ErrorCode[zipartifacts.CodeNotZip], err, "OpenArchive requires a zip file")
}
......@@ -3,8 +3,8 @@ package zipartifacts
import (
"archive/zip"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
......@@ -18,12 +18,6 @@ import (
"gitlab.com/gitlab-org/labkit/tracing"
)
// ErrNotAZip will be used when the file is not a zip archive
var ErrNotAZip = errors.New("not a zip")
// ErrArchiveNotFound will be used when the file can't be found
var ErrArchiveNotFound = errors.New("archive not found")
var httpClient = &http.Client{
Transport: tracing.NewRoundTripper(correlation.NewInstrumentedRoundTripper(&http.Transport{
Proxy: http.ProxyFromEnvironment,
......@@ -38,23 +32,52 @@ var httpClient = &http.Client{
})),
}
// archive bundles a reader over the raw archive bytes with the total archive
// size, which is needed alongside the reader to open a zip.Reader.
type archive struct {
	reader io.ReaderAt // raw archive contents (local file or HTTP range reader)
	size   int64       // total size of the archive in bytes
}
// OpenArchive will open a zip.Reader from a local path or a remote object
// store URL. In case of a remote URL it will make use of ranged requests to
// support seeking. If the path does not exist the error will be
// ErrorCode[CodeArchiveNotFound]; if the file isn't a zip archive the error
// will be ErrorCode[CodeNotZip].
func OpenArchive(ctx context.Context, archivePath string) (*zip.Reader, error) {
if isURL(archivePath) {
return openHTTPArchive(ctx, archivePath)
archive, err := openArchiveLocation(ctx, archivePath)
if err != nil {
return nil, err
}
return openFileArchive(ctx, archivePath)
return openZipReader(archive.reader, archive.size)
}
// OpenArchiveWithReaderFunc opens a zip.Reader from either a local path or a
// remote object, similarly to the OpenArchive function. The difference is
// that it allows passing a readerFunc that receives the raw io.ReaderAt
// (either an os.File or a custom reader used for object storage) together
// with the archive size, and may wrap it in any type that satisfies
// io.ReaderAt.
func OpenArchiveWithReaderFunc(ctx context.Context, location string, readerFunc func(io.ReaderAt, int64) io.ReaderAt) (*zip.Reader, error) {
	a, err := openArchiveLocation(ctx, location)
	if err != nil {
		return nil, err
	}

	// Let the caller augment the raw reader (e.g. with a read limiter).
	wrapped := readerFunc(a.reader, a.size)

	return openZipReader(wrapped, a.size)
}
// openArchiveLocation resolves the archive at the given location, fetching it
// over HTTP when the location is a URL and opening a local file otherwise.
func openArchiveLocation(ctx context.Context, location string) (*archive, error) {
	if !isURL(location) {
		return openFileArchive(ctx, location)
	}

	return openHTTPArchive(ctx, location)
}
// isURL reports whether path refers to a remote HTTP(S) location rather than
// a local file path. The check is case-sensitive, matching only lowercase
// "http://" and "https://" schemes.
func isURL(path string) bool {
	for _, scheme := range []string{"http://", "https://"} {
		if strings.HasPrefix(path, scheme) {
			return true
		}
	}

	return false
}
func openHTTPArchive(ctx context.Context, archivePath string) (*zip.Reader, error) {
func openHTTPArchive(ctx context.Context, archivePath string) (*archive, error) {
scrubbedArchivePath := mask.URL(archivePath)
req, err := http.NewRequest(http.MethodGet, archivePath, nil)
if err != nil {
......@@ -66,7 +89,7 @@ func openHTTPArchive(ctx context.Context, archivePath string) (*zip.Reader, erro
if err != nil {
return nil, fmt.Errorf("HTTP GET %q: %v", scrubbedArchivePath, err)
} else if resp.StatusCode == http.StatusNotFound {
return nil, ErrArchiveNotFound
return nil, ErrorCode[CodeArchiveNotFound]
} else if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP GET %q: %d: %v", scrubbedArchivePath, resp.StatusCode, resp.Status)
}
......@@ -79,28 +102,36 @@ func openHTTPArchive(ctx context.Context, archivePath string) (*zip.Reader, erro
rs.Close()
}()
archive, err := zip.NewReader(rs, resp.ContentLength)
if err != nil {
return nil, ErrNotAZip
}
return archive, nil
return &archive{reader: rs, size: resp.ContentLength}, nil
}
func openFileArchive(ctx context.Context, archivePath string) (*zip.Reader, error) {
archive, err := zip.OpenReader(archivePath)
func openFileArchive(ctx context.Context, archivePath string) (*archive, error) {
file, err := os.Open(archivePath)
if err != nil {
if os.IsNotExist(err) {
return nil, ErrArchiveNotFound
return nil, ErrorCode[CodeArchiveNotFound]
}
return nil, ErrNotAZip
}
go func() {