Commit 68895ecc authored by Artyom Kartasov's avatar Artyom Kartasov

feat: support personal tokens (#116)

parent ff15218e
......@@ -20,6 +20,7 @@ import (
"gitlab.com/postgres-ai/database-lab/pkg/config"
"gitlab.com/postgres-ai/database-lab/pkg/log"
"gitlab.com/postgres-ai/database-lab/pkg/services/cloning"
"gitlab.com/postgres-ai/database-lab/pkg/services/platform"
"gitlab.com/postgres-ai/database-lab/pkg/services/provision"
"gitlab.com/postgres-ai/database-lab/pkg/srv"
)
......@@ -85,11 +86,16 @@ func main() {
log.Fatalf(err)
}
platformSvc := platform.NewService(cfg.Platform)
if err := platformSvc.Init(ctx); err != nil {
log.Fatalf(errors.WithMessage(err, "failed to create a new platform service"))
}
if len(opts.VerificationToken) > 0 {
cfg.Server.VerificationToken = opts.VerificationToken
}
server := srv.NewServer(&cfg.Server, cloningSvc)
server := srv.NewServer(&cfg.Server, cloningSvc, platformSvc)
if err = server.Run(); err != nil {
log.Fatalf(err)
}
......
......@@ -5,6 +5,18 @@ server:
verificationToken: "secret_token"
port: 3000
# Postgres.ai Platform integration.
platform:
# API URL.
url: "https://postgres.ai/api/general"
# Token for authorization in Platform API. This token can be obtained in Postgres.ai Console:
# see the corresponding organization, "Access tokens" in the left menu.
accessToken: "platform_access_token"
# Enable authorization with personal tokens of the organization's members.
enablePersonalTokens: false
provision:
# Provision mode to use.
mode: "local"
......
......@@ -70,7 +70,7 @@ func NewClient(options Options, logger logrus.FieldLogger) (*Client, error) {
return &Client{
url: u,
verificationToken: options.VerificationToken,
client: &http.Client{Transport: tr},
client: &http.Client{Transport: tr, Timeout: defaultPollingTimeout},
logger: logger,
pollingInterval: defaultPollingInterval,
}, nil
......
/*
2019 © Postgres.ai
*/
// Package platform provides the Platform API client.
package platform
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"path"
"strings"
"github.com/pkg/errors"
"gitlab.com/postgres-ai/database-lab/pkg/log"
)
const (
accessToken = "Access-Token"
)
// APIResponse represents common fields of an API response.
type APIResponse struct {
Hint string `json:"hint"`
Details string `json:"details"`
Code string `json:"code"`
Message string `json:"message"`
}
// Client provides the Platform API client.
type Client struct {
url *url.URL
accessToken string
client *http.Client
}
// ClientConfig describes configuration parameters of Postgres.ai Platform client.
type ClientConfig struct {
URL string
AccessToken string
}
// NewClient creates a new Platform API client.
func NewClient(platformCfg ClientConfig) (*Client, error) {
if err := validateConfig(platformCfg); err != nil {
return nil, err
}
u, err := url.Parse(platformCfg.URL)
if err != nil {
return nil, errors.Wrap(err, "failed to parse the platform host")
}
u.Path = strings.TrimRight(u.Path, "/")
p := Client{
url: u,
accessToken: platformCfg.AccessToken,
client: &http.Client{
Transport: &http.Transport{},
},
}
return &p, nil
}
func validateConfig(config ClientConfig) error {
if config.URL == "" || config.AccessToken == "" {
return errors.New("invalid config of Platform Client given: URL and AccessToken must not be empty")
}
return nil
}
type responseParser func(*http.Response) error
func newJSONParser(v interface{}) responseParser {
return func(resp *http.Response) error {
return json.NewDecoder(resp.Body).Decode(v)
}
}
func (p *Client) doRequest(ctx context.Context, request *http.Request, parser responseParser) error {
request.Header.Add(accessToken, p.accessToken)
request = request.WithContext(ctx)
response, err := p.client.Do(request)
if err != nil {
return errors.Wrap(err, "failed to make API request")
}
defer func() { _ = response.Body.Close() }()
if response.StatusCode != http.StatusOK {
body, err := ioutil.ReadAll(response.Body)
if err != nil {
return errors.Wrap(err, "failed to read response")
}
log.Dbg(fmt.Sprintf("Response: %v", string(body)))
response.Body = ioutil.NopCloser(bytes.NewBuffer(body))
if err := parser(response); err != nil {
return errors.Wrap(err, "failed to parse response")
}
return errors.Errorf("unsuccessful status given: %d", response.StatusCode)
}
return parser(response)
}
func (p *Client) doPost(ctx context.Context, path string, data interface{}, response interface{}) error {
reqData, err := json.Marshal(data)
if err != nil {
return errors.Wrap(err, "failed to marshal request")
}
postURL := p.buildURL(path).String()
r, err := http.NewRequest(http.MethodPost, postURL, bytes.NewBuffer(reqData))
if err != nil {
return errors.Wrap(err, "failed to create request")
}
if err := p.doRequest(ctx, r, newJSONParser(&response)); err != nil {
return errors.Wrap(err, "failed to perform request")
}
return nil
}
// TokenCheckRequest represents token checking request.
type TokenCheckRequest struct {
Token string `json:"token"`
}
// TokenCheckResponse represents a response of a token checking request.
type TokenCheckResponse struct {
APIResponse
OrganizationID uint `json:"org_id"`
Personal bool `json:"is_personal"`
}
// CheckPlatformToken makes an HTTP request to check the Platform Access Token.
func (p *Client) CheckPlatformToken(ctx context.Context, request TokenCheckRequest) (TokenCheckResponse, error) {
respData := TokenCheckResponse{}
if err := p.doPost(ctx, "/rpc/dblab_token_check", request, &respData); err != nil {
return respData, errors.Wrap(err, "failed to post request")
}
if respData.Code != "" || respData.Message != "" {
log.Dbg(fmt.Sprintf("Unsuccessful response given. Request: %v", request))
return respData, errors.Errorf("error: %v", respData)
}
return respData, nil
}
// URL builds URL for a specific endpoint.
func (p *Client) buildURL(urlPath string) *url.URL {
fullPath := path.Join(p.url.Path, urlPath)
u := *p.url
u.Path = fullPath
return &u
}
/*
2019 © Postgres.ai
*/
package platform
import (
"bytes"
"context"
"encoding/json"
"io/ioutil"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// roundTripFunc represents a mock type.
type roundTripFunc func(req *http.Request) *http.Response
// RoundTrip is a mock function.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}
// NewTestClient returns a mock of *http.Client.
func NewTestClient(fn roundTripFunc) *http.Client {
return &http.Client{
Transport: fn,
}
}
func TestNewClient(t *testing.T) {
// The test case also checks if the client can be work with a no-ideal URL.
c, err := NewClient(ClientConfig{
URL: "https://example.com//",
AccessToken: "testVerify",
})
require.NoError(t, err)
assert.IsType(t, &Client{}, c)
assert.Equal(t, "https://example.com", c.url.String())
assert.Equal(t, "testVerify", c.accessToken)
assert.IsType(t, &http.Client{}, c.client)
}
func TestClientURL(t *testing.T) {
c, err := NewClient(ClientConfig{
URL: "https://example.com/",
AccessToken: "testVerify",
})
require.NoError(t, err)
assert.Equal(t, "https://example.com/test-url", c.buildURL("test-url").String())
}
func TestClientWithEmptyConfig(t *testing.T) {
testCases := []struct {
url string
token string
}{
{url: "", token: ""},
{url: "non-empty", token: ""},
{url: "", token: "non-empty"},
}
for _, tc := range testCases {
platformClient, err := NewClient(ClientConfig{
URL: tc.url,
AccessToken: tc.token,
})
require.Nil(t, platformClient)
require.NotNil(t, err)
require.Error(t, err, "invalid config of Platform Client given: URL and AccessToken must not be empty")
}
}
func TestClientChecksPlatformToken(t *testing.T) {
expectedResponse := TokenCheckResponse{
OrganizationID: 1,
Personal: true,
}
testClient := NewTestClient(func(req *http.Request) *http.Response {
body, err := json.Marshal(expectedResponse)
require.NoError(t, err)
return &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewBuffer(body)),
}
})
platformClient, err := NewClient(ClientConfig{
URL: "https://example.com/",
AccessToken: "testVerify",
})
require.NoError(t, err)
platformClient.client = testClient
platformToken, err := platformClient.CheckPlatformToken(context.Background(), TokenCheckRequest{Token: "PersonalToken"})
require.NoError(t, err)
assert.Equal(t, expectedResponse.OrganizationID, platformToken.OrganizationID)
assert.Equal(t, expectedResponse.Personal, platformToken.Personal)
}
func TestClientChecksPlatformTokenFailed(t *testing.T) {
expectedResponse := TokenCheckResponse{
APIResponse: APIResponse{
Hint: "Ensure that you use a valid and non-expired token",
Details: "Cannot find the specified token or it is expired.",
Message: "Invalid token",
},
}
testClient := NewTestClient(func(req *http.Request) *http.Response {
body, err := json.Marshal(expectedResponse)
require.NoError(t, err)
return &http.Response{
StatusCode: http.StatusUnauthorized,
Body: ioutil.NopCloser(bytes.NewBuffer(body)),
}
})
platformClient, err := NewClient(ClientConfig{
URL: "https://example.com/",
AccessToken: "testVerify",
})
require.NoError(t, err)
platformClient.client = testClient
platformToken, err := platformClient.CheckPlatformToken(context.Background(), TokenCheckRequest{Token: "PersonalToken"})
require.NotNil(t, err)
assert.Equal(t, expectedResponse.APIResponse.Message, platformToken.Message)
assert.Equal(t, expectedResponse.APIResponse.Hint, platformToken.Hint)
assert.Equal(t, expectedResponse.APIResponse.Details, platformToken.Details)
}
......@@ -11,6 +11,7 @@ import (
"os/user"
"gitlab.com/postgres-ai/database-lab/pkg/services/cloning"
"gitlab.com/postgres-ai/database-lab/pkg/services/platform"
"gitlab.com/postgres-ai/database-lab/pkg/services/provision"
"gitlab.com/postgres-ai/database-lab/pkg/srv"
"gitlab.com/postgres-ai/database-lab/pkg/util"
......@@ -24,6 +25,7 @@ type Config struct {
Server srv.Config `yaml:"server"`
Provision provision.Config `yaml:"provision"`
Cloning cloning.Config `yaml:"cloning"`
Platform platform.Config `yaml:"platform"`
Debug bool `yaml:"debug"`
}
......
......@@ -510,7 +510,7 @@ func (c *baseCloning) destroyIdleClones(ctx context.Context) {
default:
isIdleClone, err := c.isIdleClone(cloneWrapper)
if err != nil {
log.Errf("Failed to check the idleness of clone %s: %+v.", cloneWrapper.clone.ID, err)
log.Errf("Failed to check the idleness of clone %s: %v.", cloneWrapper.clone.ID, err)
continue
}
......
/*
2019 © Postgres.ai
*/
// Package platform provides a Platform service.
package platform
import (
"context"
"github.com/pkg/errors"
"gitlab.com/postgres-ai/database-lab/pkg/client/platform"
)
// PersonalTokenVerifier declares an interface of a struct for Platform Personal Token verification.
type PersonalTokenVerifier interface {
IsAllowedToken(ctx context.Context, token string) bool
IsPersonalTokenEnabled() bool
}
// Config provides configuration for the Platform service.
type Config struct {
URL string `yaml:"url"`
AccessToken string `yaml:"accessToken"`
EnablePersonalToken bool `yaml:"enablePersonalTokens"`
}
// Service defines a Platform service.
type Service struct {
cfg Config
client *platform.Client
organizationID uint
}
// NewService creates a new platform service.
func NewService(cfg Config) *Service {
return &Service{
cfg: cfg,
}
}
// Init initialize a Platform service instance.
func (s *Service) Init(ctx context.Context) error {
if !s.IsPersonalTokenEnabled() {
return nil
}
client, err := platform.NewClient(platform.ClientConfig{
URL: s.cfg.URL,
AccessToken: s.cfg.AccessToken,
})
if err != nil {
return errors.Wrap(err, "failed to create a new Platform Client")
}
s.client = client
platformToken, err := client.CheckPlatformToken(ctx, platform.TokenCheckRequest{Token: s.cfg.AccessToken})
if err != nil {
return err
}
if platformToken.OrganizationID == 0 {
return errors.New("invalid organization ID associated with the given Platform Access Token")
}
s.organizationID = platformToken.OrganizationID
return nil
}
// IsAllowedToken checks if the Platform Personal Token is allowed.
func (s *Service) IsAllowedToken(ctx context.Context, personalToken string) bool {
if !s.IsPersonalTokenEnabled() {
return true
}
platformToken, err := s.client.CheckPlatformToken(ctx, platform.TokenCheckRequest{Token: personalToken})
if err != nil {
return false
}
return s.isAllowedOrganization(platformToken.OrganizationID)
}
// IsPersonalTokenEnabled checks if the Platform Personal Token is enabled.
func (s *Service) IsPersonalTokenEnabled() bool {
return s.cfg.EnablePersonalToken
}
// isAllowedOrganization checks if organization is associated to the current Platform service.
func (s *Service) isAllowedOrganization(organizationID uint) bool {
return organizationID != 0 && organizationID == s.organizationID
}
/*
2019 © Postgres.ai
*/
package platform
import (
"testing"
"github.com/docker/docker/pkg/testutil/assert"
)
func TestIfPersonalTokenEnabled(t *testing.T) {
s := Service{}
assert.Equal(t, s.IsPersonalTokenEnabled(), false)
s.cfg.EnablePersonalToken = true
assert.Equal(t, s.IsPersonalTokenEnabled(), true)
}
func TestIfOrganizationIsAllowed(t *testing.T) {
s := Service{}
assert.Equal(t, s.isAllowedOrganization(0), false)
s.organizationID = 1
assert.Equal(t, s.isAllowedOrganization(0), false)
assert.Equal(t, s.isAllowedOrganization(1), true)
}
/*
2019 © Postgres.ai
*/
package srv
import (
"context"
"net/http"
"gitlab.com/postgres-ai/database-lab/pkg/log"
"gitlab.com/postgres-ai/database-lab/pkg/services/platform"
)
// VerificationTokenHeader defines a verification token name that should be passed in request headers.
......@@ -16,10 +22,16 @@ func logging(next http.Handler) http.Handler {
})
}
func (s *Server) authorized(h http.HandlerFunc) http.HandlerFunc {
// authMW defines an authorization middleware of the Database Lab HTTP server.
type authMW struct {
verificationToken string
personalTokenVerifier platform.PersonalTokenVerifier
}
func (a *authMW) authorized(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get(VerificationTokenHeader)
if len(token) == 0 || s.Config.VerificationToken != token {
if !a.isAccessAllowed(r.Context(), token) {
failUnauthorized(w, r)
return
}
......@@ -27,3 +39,19 @@ func (s *Server) authorized(h http.HandlerFunc) http.HandlerFunc {
h(w, r)
}
}
func (a *authMW) isAccessAllowed(ctx context.Context, token string) bool {
if token == "" {
return false
}
if a.verificationToken == token {
return true
}
if a.personalTokenVerifier.IsPersonalTokenEnabled() && a.personalTokenVerifier.IsAllowedToken(ctx, token) {
return true
}
return false
}
/*
2019 © Postgres.ai
*/
package srv
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
// Test constants.
const (
testVerificationToken = "TestToken"
testPlatformAccessToken = "PlatformAccessToken"
)
// MockPersonalTokenVerifier mocks personal verifier methods.
type MockPersonalTokenVerifier struct {
isPersonalTokenEnabled bool
}
func (m MockPersonalTokenVerifier) IsAllowedToken(_ context.Context, token string) bool {
return testPlatformAccessToken == token
}
func (m MockPersonalTokenVerifier) IsPersonalTokenEnabled() bool {
return m.isPersonalTokenEnabled
}
func TestAccess(t *testing.T) {
testCases := []struct {
name string
requestToken string
isPersonalTokenEnabled bool
result bool
}{
{isPersonalTokenEnabled: false, requestToken: "", result: false, name: "empty RequestToken with disabled PersonalToken"},
{isPersonalTokenEnabled: false, requestToken: "WrongToken", result: false, name: "wrong RequestToken with disabled PersonalToken"},
{isPersonalTokenEnabled: false, requestToken: "TestToken", result: true, name: "correct RequestToken with disabled PersonalToken"},
{isPersonalTokenEnabled: true, requestToken: "", result: false, name: "empty RequestToken with enabled PersonalToken"},
{isPersonalTokenEnabled: true, requestToken: "WrongToken", result: false, name: "wrong RequestToken with enabled PersonalToken"},
{isPersonalTokenEnabled: true, requestToken: "TestToken", result: true, name: "correct RequestToken with enabled PersonalToken"},
{isPersonalTokenEnabled: true, requestToken: "PlatformAccessToken", result: true, name: "correct PersonalToken with enabled PersonalToken"},
}
mw := authMW{
verificationToken: testVerificationToken,
}
for _, tc := range testCases {
t.Log(tc.name)
mw.personalTokenVerifier = MockPersonalTokenVerifier{isPersonalTokenEnabled: tc.result}
isAllowed := mw.isAccessAllowed(context.Background(), tc.requestToken)
assert.Equal(t, tc.result, isAllowed)
}
}
......@@ -13,32 +13,32 @@ import (
"gitlab.com/postgres-ai/database-lab/pkg/log"
"gitlab.com/postgres-ai/database-lab/pkg/services/cloning"
"gitlab.com/postgres-ai/database-lab/pkg/services/platform"
"gitlab.com/postgres-ai/database-lab/pkg/util"
"github.com/gorilla/mux"
)
// Config provides configuration for an HTTP server of the Database Lab.
type Config struct {
VerificationToken string `yaml:"verificationToken"`
Port uint `yaml:"port"`
}
// Server defines an HTTP server of the Database Lab.
type Server struct {
Config *Config
Cloning cloning.Cloning
Config *Config
Cloning cloning.Cloning
Platform *platform.Service
}
type Route struct {
Route string `json:"route"`
Methods []string `json:"methods"`
}