Skip to content
Snippets Groups Projects
Commit 67ec679d authored by Stephen Nelson's avatar Stephen Nelson Committed by Stephen Nelson
Browse files

Add support for TLS client authentication

Allow gitlab-ci to connect to a gitlab host using TLS client authentication
(mutual authentication). Adds configuration and support for using TLS client
certificates when using go's TLS transport layer and also sets git enviromental
variables for runners.
parent ef0ca338
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !157. Comments created here will be created in the context of that merge request.
...@@ -381,11 +381,36 @@ func (b *Build) GetDefaultVariables() JobVariables { ...@@ -381,11 +381,36 @@ func (b *Build) GetDefaultVariables() JobVariables {
} }
} }
func (b *Build) GetCITLSVariables() JobVariables {
variables := JobVariables{}
if b.TLSCAChain != "" {
variables = append(variables, JobVariable{"CI_SERVER_TLS_CA_FILE", b.TLSCAChain, true, true, true})
}
if b.TLSAuthCert != "" && b.TLSAuthKey != "" {
variables = append(variables, JobVariable{"CI_SERVER_TLS_CERT_FILE", b.TLSAuthCert, true, true, true})
variables = append(variables, JobVariable{"CI_SERVER_TLS_KEY_FILE", b.TLSAuthKey, true, true, true})
}
return variables
}
func (b *Build) GetGitTLSVariables() JobVariables {
variables := JobVariables{}
if b.TLSCAChain != "" {
variables = append(variables, JobVariable{"GIT_SSL_CAINFO", b.TLSCAChain, true, true, true})
}
if b.TLSAuthCert != "" && b.TLSAuthKey != "" {
variables = append(variables, JobVariable{"GIT_SSL_CERT", b.TLSAuthCert, true, true, true})
variables = append(variables, JobVariable{"GIT_SSL_KEY", b.TLSAuthKey, true, true, true})
}
return variables
}
func (b *Build) GetAllVariables() (variables JobVariables) { func (b *Build) GetAllVariables() (variables JobVariables) {
if b.Runner != nil { if b.Runner != nil {
variables = append(variables, b.Runner.GetVariables()...) variables = append(variables, b.Runner.GetVariables()...)
} }
variables = append(variables, b.GetDefaultVariables()...) variables = append(variables, b.GetDefaultVariables()...)
variables = append(variables, b.GetCITLSVariables()...)
variables = append(variables, b.Variables...) variables = append(variables, b.Variables...)
return variables.Expand() return variables.Expand()
} }
......
...@@ -148,9 +148,11 @@ type KubernetesConfig struct { ...@@ -148,9 +148,11 @@ type KubernetesConfig struct {
} }
type RunnerCredentials struct { type RunnerCredentials struct {
URL string `toml:"url" json:"url" short:"u" long:"url" env:"CI_SERVER_URL" required:"true" description:"Runner URL"` URL string `toml:"url" json:"url" short:"u" long:"url" env:"CI_SERVER_URL" required:"true" description:"Runner URL"`
Token string `toml:"token" json:"token" short:"t" long:"token" env:"CI_SERVER_TOKEN" required:"true" description:"Runner token"` Token string `toml:"token" json:"token" short:"t" long:"token" env:"CI_SERVER_TOKEN" required:"true" description:"Runner token"`
TLSCAFile string `toml:"tls-ca-file,omitempty" json:"tls-ca-file" long:"tls-ca-file" env:"CI_SERVER_TLS_CA_FILE" description:"File containing the certificates to verify the peer when using HTTPS"` TLSCAFile string `toml:"tls-ca-file,omitempty" json:"tls-ca-file" long:"tls-ca-file" env:"CI_SERVER_TLS_CA_FILE" description:"File containing the certificates to verify the peer when using HTTPS"`
TLSCertFile string `toml:"tls-cert-file,omitempty" json:"tls-cert-file" long:"tls-cert-file" env:"CI_SERVER_TLS_CERT_FILE" description:"File containing certificate for TLS client auth when using HTTPS"`
TLSKeyFile string `toml:"tls-key-file,omitempty" json:"tls-key-file" long:"tls-key-file" env:"CI_SERVER_TLS_KEY_FILE" description:"File containing private key for TLS client auth when using HTTPS"`
} }
type CacheConfig struct { type CacheConfig struct {
...@@ -277,6 +279,14 @@ func (c *RunnerCredentials) GetTLSCAFile() string { ...@@ -277,6 +279,14 @@ func (c *RunnerCredentials) GetTLSCAFile() string {
return c.TLSCAFile return c.TLSCAFile
} }
func (c *RunnerCredentials) GetTLSCertFile() string {
return c.TLSCertFile
}
func (c *RunnerCredentials) GetTLSKeyFile() string {
return c.TLSKeyFile
}
func (c *RunnerCredentials) GetToken() string { func (c *RunnerCredentials) GetToken() string {
return c.Token return c.Token
} }
......
...@@ -209,7 +209,9 @@ type JobResponse struct { ...@@ -209,7 +209,9 @@ type JobResponse struct {
Credentials []Credentials `json:"credentials"` Credentials []Credentials `json:"credentials"`
Dependencies Dependencies `json:"dependencies"` Dependencies Dependencies `json:"dependencies"`
TLSCAChain string `json:"-"` TLSCAChain string `json:"-"`
TLSAuthCert string `json:"-"`
TLSAuthKey string `json:"-"`
} }
func (j *JobResponse) RepoCleanURL() string { func (j *JobResponse) RepoCleanURL() string {
...@@ -224,10 +226,12 @@ type UpdateJobRequest struct { ...@@ -224,10 +226,12 @@ type UpdateJobRequest struct {
} }
type JobCredentials struct { type JobCredentials struct {
ID int `long:"id" env:"CI_BUILD_ID" description:"The build ID to upload artifacts for"` ID int `long:"id" env:"CI_BUILD_ID" description:"The build ID to upload artifacts for"`
Token string `long:"token" env:"CI_BUILD_TOKEN" required:"true" description:"Build token"` Token string `long:"token" env:"CI_BUILD_TOKEN" required:"true" description:"Build token"`
URL string `long:"url" env:"CI_SERVER_URL" required:"true" description:"GitLab CI URL"` URL string `long:"url" env:"CI_SERVER_URL" required:"true" description:"GitLab CI URL"`
TLSCAFile string `long:"tls-ca-file" env:"CI_SERVER_TLS_CA_FILE" description:"File containing the certificates to verify the peer when using HTTPS"` TLSCAFile string `long:"tls-ca-file" env:"CI_SERVER_TLS_CA_FILE" description:"File containing the certificates to verify the peer when using HTTPS"`
TLSCertFile string `long:"tls-cert-file" env:"CI_SERVER_TLS_CERT_FILE" description:"File containing certificate for TLS client auth with runner when using HTTPS"`
TLSKeyFile string `long:"tls-key-file" env:"CI_SERVER_TLS_KEY_FILE" description:"File containing private key for TLS client auth with runner when using HTTPS"`
} }
func (j *JobCredentials) GetURL() string { func (j *JobCredentials) GetURL() string {
...@@ -238,6 +242,14 @@ func (j *JobCredentials) GetTLSCAFile() string { ...@@ -238,6 +242,14 @@ func (j *JobCredentials) GetTLSCAFile() string {
return j.TLSCAFile return j.TLSCAFile
} }
func (j *JobCredentials) GetTLSCertFile() string {
return j.TLSCertFile
}
func (j *JobCredentials) GetTLSKeyFile() string {
return j.TLSKeyFile
}
func (j *JobCredentials) GetToken() string { func (j *JobCredentials) GetToken() string {
return j.Token return j.Token
} }
......
...@@ -37,6 +37,8 @@ This defines one runner entry. ...@@ -37,6 +37,8 @@ This defines one runner entry.
| `url` | CI URL | | `url` | CI URL |
| `token` | runner token | | `token` | runner token |
| `tls-ca-file` | file containing the certificates to verify the peer when using HTTPS | | `tls-ca-file` | file containing the certificates to verify the peer when using HTTPS |
| `tls-cert-file` | file containing the certificate to authenticate with the peer when using HTTPS |
| `tls-key-file` | file containing the private key to authenticate with the peer when using HTTPS |
| `tls-skip-verify` | whether to verify the TLS certificate when using HTTPS, default: false | | `tls-skip-verify` | whether to verify the TLS certificate when using HTTPS, default: false |
| `limit` | limit how many jobs can be handled concurrently by this token. `0` (default) simply means don't limit | | `limit` | limit how many jobs can be handled concurrently by this token. `0` (default) simply means don't limit |
| `executor` | select how a project should be built, see next section | | `executor` | select how a project should be built, see next section |
......
...@@ -29,6 +29,8 @@ type requestCredentials interface { ...@@ -29,6 +29,8 @@ type requestCredentials interface {
GetURL() string GetURL() string
GetToken() string GetToken() string
GetTLSCAFile() string GetTLSCAFile() string
GetTLSCertFile() string
GetTLSKeyFile() string
} }
var dialer = net.Dialer{ var dialer = net.Dialer{
...@@ -40,6 +42,8 @@ type client struct { ...@@ -40,6 +42,8 @@ type client struct {
http.Client http.Client
url *url.URL url *url.URL
caFile string caFile string
certFile string
keyFile string
caData []byte caData []byte
skipVerify bool skipVerify bool
updateTime time.Time updateTime time.Time
...@@ -62,6 +66,16 @@ func (n *client) ensureTLSConfig() { ...@@ -62,6 +66,16 @@ func (n *client) ensureTLSConfig() {
n.Transport = nil n.Transport = nil
} }
// client certificate got modified
if stat, err := os.Stat(n.certFile); err == nil && n.updateTime.Before(stat.ModTime()) {
n.Transport = nil
}
// client private key got modified
if stat, err := os.Stat(n.keyFile); err == nil && n.updateTime.Before(stat.ModTime()) {
n.Transport = nil
}
// create or update transport // create or update transport
if n.Transport == nil { if n.Transport == nil {
n.updateTime = time.Now() n.updateTime = time.Now()
...@@ -69,14 +83,8 @@ func (n *client) ensureTLSConfig() { ...@@ -69,14 +83,8 @@ func (n *client) ensureTLSConfig() {
} }
} }
func (n *client) createTransport() { func (n *client) addTLSCA(tlsConfig *tls.Config) {
// create reference TLS config // load TLS CA certificate
tlsConfig := tls.Config{
MinVersion: tls.VersionTLS10,
InsecureSkipVerify: n.skipVerify,
}
// load TLS certificate
if file := n.caFile; file != "" && !n.skipVerify { if file := n.caFile; file != "" && !n.skipVerify {
logrus.Debugln("Trying to load", file, "...") logrus.Debugln("Trying to load", file, "...")
...@@ -95,6 +103,34 @@ func (n *client) createTransport() { ...@@ -95,6 +103,34 @@ func (n *client) createTransport() {
} }
} }
} }
}
func (n *client) addTLSAuth(tlsConfig *tls.Config) {
// load TLS client keypair
if cert, key := n.certFile, n.keyFile; cert != "" && key != "" {
logrus.Debugln("Trying to load", cert, "and", key, "pair...")
certificate, err := tls.LoadX509KeyPair(cert, key)
if err == nil {
tlsConfig.Certificates = []tls.Certificate{certificate}
tlsConfig.BuildNameToCertificate()
} else {
if !os.IsNotExist(err) {
logrus.Errorln("Failed to load", cert, key, err)
}
}
}
}
func (n *client) createTransport() {
// create reference TLS config
tlsConfig := tls.Config{
MinVersion: tls.VersionTLS10,
InsecureSkipVerify: n.skipVerify,
}
n.addTLSCA(&tlsConfig)
n.addTLSAuth(&tlsConfig)
// create transport // create transport
n.Transport = &http.Transport{ n.Transport = &http.Transport{
...@@ -174,13 +210,13 @@ func (n *client) do(uri, method string, request io.Reader, requestType string, h ...@@ -174,13 +210,13 @@ func (n *client) do(uri, method string, request io.Reader, requestType string, h
return return
} }
func (n *client) doJSON(uri, method string, statusCode int, request interface{}, response interface{}) (int, string, string) { func (n *client) doJSON(uri, method string, statusCode int, request interface{}, response interface{}) (int, string, string, string, string) {
var body io.Reader var body io.Reader
if request != nil { if request != nil {
requestBody, err := json.Marshal(request) requestBody, err := json.Marshal(request)
if err != nil { if err != nil {
return -1, fmt.Sprintf("failed to marshal project object: %v", err), "" return -1, fmt.Sprintf("failed to marshal project object: %v", err), "", "", ""
} }
body = bytes.NewReader(requestBody) body = bytes.NewReader(requestBody)
} }
...@@ -192,7 +228,7 @@ func (n *client) doJSON(uri, method string, statusCode int, request interface{}, ...@@ -192,7 +228,7 @@ func (n *client) doJSON(uri, method string, statusCode int, request interface{},
res, err := n.do(uri, method, body, "application/json", headers) res, err := n.do(uri, method, body, "application/json", headers)
if err != nil { if err != nil {
return -1, err.Error(), "" return -1, err.Error(), "", "", ""
} }
defer res.Body.Close() defer res.Body.Close()
defer io.Copy(ioutil.Discard, res.Body) defer io.Copy(ioutil.Discard, res.Body)
...@@ -201,20 +237,20 @@ func (n *client) doJSON(uri, method string, statusCode int, request interface{}, ...@@ -201,20 +237,20 @@ func (n *client) doJSON(uri, method string, statusCode int, request interface{},
if response != nil { if response != nil {
isApplicationJSON, err := isResponseApplicationJSON(res) isApplicationJSON, err := isResponseApplicationJSON(res)
if !isApplicationJSON { if !isApplicationJSON {
return -1, err.Error(), "" return -1, err.Error(), "", "", ""
} }
d := json.NewDecoder(res.Body) d := json.NewDecoder(res.Body)
err = d.Decode(response) err = d.Decode(response)
if err != nil { if err != nil {
return -1, fmt.Sprintf("Error decoding json payload %v", err), "" return -1, fmt.Sprintf("Error decoding json payload %v", err), "", "", ""
} }
} }
} }
n.setLastUpdate(res.Header) n.setLastUpdate(res.Header)
return res.StatusCode, res.Status, n.getCAChain(res.TLS) return res.StatusCode, res.Status, n.getCAChain(res.TLS), n.certFile, n.keyFile
} }
func isResponseApplicationJSON(res *http.Response) (result bool, err error) { func isResponseApplicationJSON(res *http.Response) (result bool, err error) {
...@@ -240,6 +276,16 @@ func fixCIURL(url string) string { ...@@ -240,6 +276,16 @@ func fixCIURL(url string) string {
return url return url
} }
func (n *client) findCertificate(certificate *string, base string, name string) {
if *certificate != "" {
return
}
path := filepath.Join(base, name)
if _, err := os.Stat(path); err == nil {
*certificate = path
}
}
func newClient(requestCredentials requestCredentials) (c *client, err error) { func newClient(requestCredentials requestCredentials) (c *client, err error) {
url, err := url.Parse(fixCIURL(requestCredentials.GetURL()) + "/api/v4/") url, err := url.Parse(fixCIURL(requestCredentials.GetURL()) + "/api/v4/")
if err != nil { if err != nil {
...@@ -252,13 +298,17 @@ func newClient(requestCredentials requestCredentials) (c *client, err error) { ...@@ -252,13 +298,17 @@ func newClient(requestCredentials requestCredentials) (c *client, err error) {
} }
c = &client{ c = &client{
url: url, url: url,
caFile: requestCredentials.GetTLSCAFile(), caFile: requestCredentials.GetTLSCAFile(),
certFile: requestCredentials.GetTLSCertFile(),
keyFile: requestCredentials.GetTLSKeyFile(),
} }
if CertificateDirectory != "" && c.caFile == "" { host := strings.Split(url.Host, ":")[0]
hostAndPort := strings.Split(url.Host, ":") if CertificateDirectory != "" {
c.caFile = filepath.Join(CertificateDirectory, hostAndPort[0]+".crt") c.findCertificate(&c.caFile, CertificateDirectory, host+".crt")
c.findCertificate(&c.certFile, CertificateDirectory, host+".auth.crt")
c.findCertificate(&c.keyFile, CertificateDirectory, host+".auth.key")
} }
return return
......
package network package network
import ( import (
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
...@@ -56,6 +59,33 @@ func writeTLSCertificate(s *httptest.Server, file string) error { ...@@ -56,6 +59,33 @@ func writeTLSCertificate(s *httptest.Server, file string) error {
return ioutil.WriteFile(file, encoded, 0600) return ioutil.WriteFile(file, encoded, 0600)
} }
func writeTLSKeyPair(s *httptest.Server, certFile string, keyFile string) error {
c := s.TLS.Certificates[0]
if c.Certificate == nil || c.Certificate[0] == nil {
return errors.New("no predefined certificate")
}
encodedCert := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: c.Certificate[0],
})
if err := ioutil.WriteFile(certFile, encodedCert, 0600); err != nil {
return err
}
switch k := c.PrivateKey.(type) {
case *rsa.PrivateKey:
encodedKey := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(k),
})
return ioutil.WriteFile(keyFile, encodedKey, 0600)
default:
return errors.New("unexpected private key type")
}
}
func TestNewClient(t *testing.T) { func TestNewClient(t *testing.T) {
c, err := newClient(&RunnerCredentials{ c, err := newClient(&RunnerCredentials{
URL: "http://test.example.com/ci///", URL: "http://test.example.com/ci///",
...@@ -82,7 +112,7 @@ func TestClientDo(t *testing.T) { ...@@ -82,7 +112,7 @@ func TestClientDo(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, c) assert.NotNil(t, c)
statusCode, statusText, _ := c.doJSON("test/auth", "GET", 200, nil, nil) statusCode, statusText, _, _, _ := c.doJSON("test/auth", "GET", 200, nil, nil)
assert.Equal(t, 403, statusCode, statusText) assert.Equal(t, 403, statusCode, statusText)
req := struct { req := struct {
...@@ -95,16 +125,16 @@ func TestClientDo(t *testing.T) { ...@@ -95,16 +125,16 @@ func TestClientDo(t *testing.T) {
Key string `json:"key"` Key string `json:"key"`
}{} }{}
statusCode, statusText, _ = c.doJSON("test/json", "GET", 200, nil, &res) statusCode, statusText, _, _, _ = c.doJSON("test/json", "GET", 200, nil, &res)
assert.Equal(t, 400, statusCode, statusText) assert.Equal(t, 400, statusCode, statusText)
statusCode, statusText, _ = c.doJSON("test/json", "GET", 200, &req, nil) statusCode, statusText, _, _, _ = c.doJSON("test/json", "GET", 200, &req, nil)
assert.Equal(t, 406, statusCode, statusText) assert.Equal(t, 406, statusCode, statusText)
statusCode, statusText, _ = c.doJSON("test/json", "GET", 200, nil, nil) statusCode, statusText, _, _, _ = c.doJSON("test/json", "GET", 200, nil, nil)
assert.Equal(t, 400, statusCode, statusText) assert.Equal(t, 400, statusCode, statusText)
statusCode, statusText, _ = c.doJSON("test/json", "GET", 200, &req, &res) statusCode, statusText, _, _, _ = c.doJSON("test/json", "GET", 200, &req, &res)
assert.Equal(t, 200, statusCode, statusText) assert.Equal(t, 200, statusCode, statusText)
assert.Equal(t, "value", res.Key, statusText) assert.Equal(t, "value", res.Key, statusText)
} }
...@@ -116,7 +146,7 @@ func TestClientInvalidSSL(t *testing.T) { ...@@ -116,7 +146,7 @@ func TestClientInvalidSSL(t *testing.T) {
c, _ := newClient(&RunnerCredentials{ c, _ := newClient(&RunnerCredentials{
URL: s.URL, URL: s.URL,
}) })
statusCode, statusText, _ := c.doJSON("test/ok", "GET", 200, nil, nil) statusCode, statusText, _, _, _ := c.doJSON("test/ok", "GET", 200, nil, nil)
assert.Equal(t, -1, statusCode, statusText) assert.Equal(t, -1, statusCode, statusText)
assert.Contains(t, statusText, "certificate signed by unknown authority") assert.Contains(t, statusText, "certificate signed by unknown authority")
} }
...@@ -137,7 +167,7 @@ func TestClientTLSCAFile(t *testing.T) { ...@@ -137,7 +167,7 @@ func TestClientTLSCAFile(t *testing.T) {
URL: s.URL, URL: s.URL,
TLSCAFile: file.Name(), TLSCAFile: file.Name(),
}) })
statusCode, statusText, certificates := c.doJSON("test/ok", "GET", 200, nil, nil) statusCode, statusText, certificates, _, _ := c.doJSON("test/ok", "GET", 200, nil, nil)
assert.Equal(t, 200, statusCode, statusText) assert.Equal(t, 200, statusCode, statusText)
assert.NotEmpty(t, certificates) assert.NotEmpty(t, certificates)
} }
...@@ -157,9 +187,104 @@ func TestClientCertificateInPredefinedDirectory(t *testing.T) { ...@@ -157,9 +187,104 @@ func TestClientCertificateInPredefinedDirectory(t *testing.T) {
c, _ := newClient(&RunnerCredentials{ c, _ := newClient(&RunnerCredentials{
URL: s.URL, URL: s.URL,
}) })
statusCode, statusText, certificates := c.doJSON("test/ok", "GET", 200, nil, nil) statusCode, statusText, certificates, _, _ := c.doJSON("test/ok", "GET", 200, nil, nil)
assert.Equal(t, 200, statusCode, statusText)
assert.NotEmpty(t, certificates)
}
func TestClientInvalidTLSAuth(t *testing.T) {
s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler))
s.TLS = new(tls.Config)
s.TLS.ClientAuth = tls.RequireAnyClientCert
s.StartTLS()
defer s.Close()
ca, err := ioutil.TempFile("", "cert_")
assert.NoError(t, err)
ca.Close()
defer os.Remove(ca.Name())
err = writeTLSCertificate(s, ca.Name())
assert.NoError(t, err)
c, _ := newClient(&RunnerCredentials{
URL: s.URL,
TLSCAFile: ca.Name(),
})
statusCode, statusText, _, _, _ := c.doJSON("test/ok", "GET", 200, nil, nil)
assert.Equal(t, -1, statusCode, statusText)
assert.Contains(t, statusText, "tls: bad certificate")
}
func TestClientTLSAuth(t *testing.T) {
s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler))
s.TLS = new(tls.Config)
s.TLS.ClientAuth = tls.RequireAnyClientCert
s.StartTLS()
defer s.Close()
ca, err := ioutil.TempFile("", "cert_")
assert.NoError(t, err)
ca.Close()
defer os.Remove(ca.Name())
err = writeTLSCertificate(s, ca.Name())
assert.NoError(t, err)
cert, err := ioutil.TempFile("", "cert_")
assert.NoError(t, err)
cert.Close()
defer os.Remove(cert.Name())
key, err := ioutil.TempFile("", "key_")
assert.NoError(t, err)
key.Close()
defer os.Remove(key.Name())
err = writeTLSKeyPair(s, cert.Name(), key.Name())
assert.NoError(t, err)
c, _ := newClient(&RunnerCredentials{
URL: s.URL,
TLSCAFile: ca.Name(),
TLSCertFile: cert.Name(),
TLSKeyFile: key.Name(),
})
statusCode, statusText, certificates, certFile, keyFile := c.doJSON("test/ok", "GET", 200, nil, nil)
assert.Equal(t, 200, statusCode, statusText)
assert.NotEmpty(t, certificates)
assert.Equal(t, cert.Name(), certFile)
assert.Equal(t, key.Name(), keyFile)
}
func TestClientTLSAuthCertificatesInPredefinedDirectory(t *testing.T) {
s := httptest.NewUnstartedServer(http.HandlerFunc(clientHandler))
s.TLS = new(tls.Config)
s.TLS.ClientAuth = tls.RequireAnyClientCert
s.StartTLS()
defer s.Close()
tempDir, err := ioutil.TempDir("", "certs")
assert.NoError(t, err)
defer os.RemoveAll(tempDir)
CertificateDirectory = tempDir
err = writeTLSCertificate(s, filepath.Join(tempDir, "127.0.0.1.crt"))
assert.NoError(t, err)
err = writeTLSKeyPair(s,
filepath.Join(tempDir, "127.0.0.1.auth.crt"),
filepath.Join(tempDir, "127.0.0.1.auth.key"))
assert.NoError(t, err)
c, _ := newClient(&RunnerCredentials{
URL: s.URL,
})
statusCode, statusText, certificates, cert, key := c.doJSON("test/ok", "GET", 200, nil, nil)
assert.Equal(t, 200, statusCode, statusText) assert.Equal(t, 200, statusCode, statusText)
assert.NotEmpty(t, certificates) assert.NotEmpty(t, certificates)
assert.NotEmpty(t, cert)
assert.NotEmpty(t, key)
} }
func TestUrlFixing(t *testing.T) { func TestUrlFixing(t *testing.T) {
...@@ -210,15 +335,15 @@ func TestClientHandleCharsetInContentType(t *testing.T) { ...@@ -210,15 +335,15 @@ func TestClientHandleCharsetInContentType(t *testing.T) {
Key string `json:"key"` Key string `json:"key"`
}{} }{}
statusCode, statusText, _ := c.doJSON("with-charset", "GET", 200, nil, &res) statusCode, statusText, _, _, _ := c.doJSON("with-charset", "GET", 200, nil, &res)
assert.Equal(t, 200, statusCode, statusText) assert.Equal(t, 200, statusCode, statusText)
statusCode, statusText, _ = c.doJSON("without-charset", "GET", 200, nil, &res) statusCode, statusText, _, _, _ = c.doJSON("without-charset", "GET", 200, nil, &res)
assert.Equal(t, 200, statusCode, statusText) assert.Equal(t, 200, statusCode, statusText)
statusCode, statusText, _ = c.doJSON("without-json", "GET", 200, nil, &res) statusCode, statusText, _, _, _ = c.doJSON("without-json", "GET", 200, nil, &res)
assert.Equal(t, -1, statusCode, statusText) assert.Equal(t, -1, statusCode, statusText)
statusCode, statusText, _ = c.doJSON("invalid-header", "GET", 200, nil, &res) statusCode, statusText, _, _, _ = c.doJSON("invalid-header", "GET", 200, nil, &res)
assert.Equal(t, -1, statusCode, statusText) assert.Equal(t, -1, statusCode, statusText)
} }
...@@ -33,7 +33,7 @@ func (n *GitLabClient) getClient(credentials requestCredentials) (c *client, err ...@@ -33,7 +33,7 @@ func (n *GitLabClient) getClient(credentials requestCredentials) (c *client, err
if n.clients == nil { if n.clients == nil {
n.clients = make(map[string]*client) n.clients = make(map[string]*client)
} }
key := fmt.Sprintf("%s_%s_%s", credentials.GetURL(), credentials.GetToken(), credentials.GetTLSCAFile()) key := fmt.Sprintf("%s_%s_%s_%s", credentials.GetURL(), credentials.GetToken(), credentials.GetTLSCAFile(), credentials.GetTLSCertFile())
c = n.clients[key] c = n.clients[key]
if c == nil { if c == nil {
c, err = newClient(credentials) c, err = newClient(credentials)
...@@ -83,10 +83,10 @@ func (n *GitLabClient) doRaw(credentials requestCredentials, method, uri string, ...@@ -83,10 +83,10 @@ func (n *GitLabClient) doRaw(credentials requestCredentials, method, uri string,
return c.do(uri, method, request, requestType, headers) return c.do(uri, method, request, requestType, headers)
} }
func (n *GitLabClient) doJSON(credentials requestCredentials, method, uri string, statusCode int, request interface{}, response interface{}) (int, string, string) { func (n *GitLabClient) doJSON(credentials requestCredentials, method, uri string, statusCode int, request interface{}, response interface{}) (int, string, string, string, string) {
c, err := n.getClient(credentials) c, err := n.getClient(credentials)
if err != nil { if err != nil {
return clientError, err.Error(), "" return clientError, err.Error(), "", "", ""
} }
return c.doJSON(uri, method, statusCode, request, response) return c.doJSON(uri, method, statusCode, request, response)
...@@ -104,7 +104,7 @@ func (n *GitLabClient) RegisterRunner(runner common.RunnerCredentials, descripti ...@@ -104,7 +104,7 @@ func (n *GitLabClient) RegisterRunner(runner common.RunnerCredentials, descripti
} }
var response common.RegisterRunnerResponse var response common.RegisterRunnerResponse
result, statusText, _ := n.doJSON(&runner, "POST", "runners", 201, &request, &response) result, statusText, _, _, _ := n.doJSON(&runner, "POST", "runners", 201, &request, &response)
switch result { switch result {
case 201: case 201:
...@@ -127,7 +127,7 @@ func (n *GitLabClient) VerifyRunner(runner common.RunnerCredentials) bool { ...@@ -127,7 +127,7 @@ func (n *GitLabClient) VerifyRunner(runner common.RunnerCredentials) bool {
Token: runner.Token, Token: runner.Token,
} }
result, statusText, _ := n.doJSON(&runner, "POST", "runners/verify", 200, &request, nil) result, statusText, _, _, _ := n.doJSON(&runner, "POST", "runners/verify", 200, &request, nil)
switch result { switch result {
case 200: case 200:
...@@ -151,7 +151,7 @@ func (n *GitLabClient) UnregisterRunner(runner common.RunnerCredentials) bool { ...@@ -151,7 +151,7 @@ func (n *GitLabClient) UnregisterRunner(runner common.RunnerCredentials) bool {
Token: runner.Token, Token: runner.Token,
} }
result, statusText, _ := n.doJSON(&runner, "DELETE", "runners", 200, &request, nil) result, statusText, _, _, _ := n.doJSON(&runner, "DELETE", "runners", 200, &request, nil)
switch result { switch result {
case 204: case 204:
...@@ -169,6 +169,20 @@ func (n *GitLabClient) UnregisterRunner(runner common.RunnerCredentials) bool { ...@@ -169,6 +169,20 @@ func (n *GitLabClient) UnregisterRunner(runner common.RunnerCredentials) bool {
} }
} }
func addTLSAuth(response *common.JobResponse, cert string, key string) {
if cert != "" && key != "" {
data, err := ioutil.ReadFile(cert)
if err == nil {
response.TLSAuthCert = string(data)
}
data, err = ioutil.ReadFile(key)
if err == nil {
response.TLSAuthKey = string(data)
}
}
}
func (n *GitLabClient) RequestJob(config common.RunnerConfig) (*common.JobResponse, bool) { func (n *GitLabClient) RequestJob(config common.RunnerConfig) (*common.JobResponse, bool) {
request := common.JobRequest{ request := common.JobRequest{
Info: n.getRunnerVersion(config), Info: n.getRunnerVersion(config),
...@@ -177,7 +191,7 @@ func (n *GitLabClient) RequestJob(config common.RunnerConfig) (*common.JobRespon ...@@ -177,7 +191,7 @@ func (n *GitLabClient) RequestJob(config common.RunnerConfig) (*common.JobRespon
} }
var response common.JobResponse var response common.JobResponse
result, statusText, certificates := n.doJSON(&config.RunnerCredentials, "POST", "jobs/request", 201, &request, &response) result, statusText, caChain, cert, key := n.doJSON(&config.RunnerCredentials, "POST", "jobs/request", 201, &request, &response)
switch result { switch result {
case 201: case 201:
...@@ -185,7 +199,8 @@ func (n *GitLabClient) RequestJob(config common.RunnerConfig) (*common.JobRespon ...@@ -185,7 +199,8 @@ func (n *GitLabClient) RequestJob(config common.RunnerConfig) (*common.JobRespon
"job": strconv.Itoa(response.ID), "job": strconv.Itoa(response.ID),
"repo_url": response.RepoCleanURL(), "repo_url": response.RepoCleanURL(),
}).Println("Checking for jobs...", "received") }).Println("Checking for jobs...", "received")
response.TLSCAChain = certificates response.TLSCAChain = caChain
addTLSAuth(&response, cert, key)
return &response, true return &response, true
case 403: case 403:
config.Log().Errorln("Checking for jobs...", "forbidden") config.Log().Errorln("Checking for jobs...", "forbidden")
...@@ -212,7 +227,7 @@ func (n *GitLabClient) UpdateJob(config common.RunnerConfig, jobCredentials *com ...@@ -212,7 +227,7 @@ func (n *GitLabClient) UpdateJob(config common.RunnerConfig, jobCredentials *com
log := config.Log().WithField("job", id) log := config.Log().WithField("job", id)
result, statusText, _ := n.doJSON(&config.RunnerCredentials, "PUT", fmt.Sprintf("jobs/%d", id), 200, &request, nil) result, statusText, _, _, _ := n.doJSON(&config.RunnerCredentials, "PUT", fmt.Sprintf("jobs/%d", id), 200, &request, nil)
switch result { switch result {
case 200: case 200:
log.Debugln("Submitting job to coordinator...", "ok") log.Debugln("Submitting job to coordinator...", "ok")
......
...@@ -41,12 +41,26 @@ func TestClients(t *testing.T) { ...@@ -41,12 +41,26 @@ func TestClients(t *testing.T) {
URL: "http://test/", URL: "http://test/",
TLSCAFile: "ca_file", TLSCAFile: "ca_file",
}) })
c6, c6err := c.getClient(&brokenCredentials) c6, _ := c.getClient(&RunnerCredentials{
URL: "http://test/",
TLSCAFile: "ca_file",
TLSCertFile: "cert_file",
TLSKeyFile: "key_file",
})
c7, _ := c.getClient(&RunnerCredentials{
URL: "http://test/",
TLSCAFile: "ca_file",
TLSCertFile: "cert_file",
TLSKeyFile: "key_file2",
})
c8, c8err := c.getClient(&brokenCredentials)
assert.NotEqual(t, c1, c2) assert.NotEqual(t, c1, c2)
assert.NotEqual(t, c1, c4) assert.NotEqual(t, c1, c4)
assert.Equal(t, c4, c5) assert.Equal(t, c4, c5)
assert.Nil(t, c6) assert.NotEqual(t, c5, c6)
assert.Error(t, c6err) assert.Equal(t, c6, c7)
assert.Nil(t, c8)
assert.Error(t, c8err)
} }
func testRegisterRunnerHandler(w http.ResponseWriter, r *http.Request, t *testing.T) { func testRegisterRunnerHandler(w http.ResponseWriter, r *http.Request, t *testing.T) {
......
...@@ -28,15 +28,9 @@ func (b *AbstractShell) writeExports(w ShellWriter, info common.ShellScriptInfo) ...@@ -28,15 +28,9 @@ func (b *AbstractShell) writeExports(w ShellWriter, info common.ShellScriptInfo)
} }
} }
func (b *AbstractShell) writeTLSCAInfo(w ShellWriter, build *common.Build, key string) { func (b *AbstractShell) writeGitExports(w ShellWriter, info common.ShellScriptInfo) {
if build.TLSCAChain != "" { for _, variable := range info.Build.GetGitTLSVariables() {
w.Variable(common.JobVariable{ w.Variable(variable)
Key: key,
Value: build.TLSCAChain,
Public: true,
Internal: true,
File: true,
})
} }
} }
...@@ -293,7 +287,7 @@ func (b *AbstractShell) writeSubmoduleUpdateCmds(w ShellWriter, info common.Shel ...@@ -293,7 +287,7 @@ func (b *AbstractShell) writeSubmoduleUpdateCmds(w ShellWriter, info common.Shel
func (b *AbstractShell) writeGetSourcesScript(w ShellWriter, info common.ShellScriptInfo) (err error) { func (b *AbstractShell) writeGetSourcesScript(w ShellWriter, info common.ShellScriptInfo) (err error) {
b.writeExports(w, info) b.writeExports(w, info)
b.writeTLSCAInfo(w, info.Build, "GIT_SSL_CAINFO") b.writeGitExports(w, info)
if info.PreCloneScript != "" && info.Build.GetGitStrategy() != common.GitNone { if info.PreCloneScript != "" && info.Build.GetGitStrategy() != common.GitNone {
b.writeCommands(w, info.PreCloneScript) b.writeCommands(w, info.PreCloneScript)
...@@ -313,7 +307,6 @@ func (b *AbstractShell) writeGetSourcesScript(w ShellWriter, info common.ShellSc ...@@ -313,7 +307,6 @@ func (b *AbstractShell) writeGetSourcesScript(w ShellWriter, info common.ShellSc
func (b *AbstractShell) writeRestoreCacheScript(w ShellWriter, info common.ShellScriptInfo) (err error) { func (b *AbstractShell) writeRestoreCacheScript(w ShellWriter, info common.ShellScriptInfo) (err error) {
b.writeExports(w, info) b.writeExports(w, info)
b.writeCdBuildDir(w, info) b.writeCdBuildDir(w, info)
b.writeTLSCAInfo(w, info.Build, "CI_SERVER_TLS_CA_FILE")
// Try to restore from main cache, if not found cache for master // Try to restore from main cache, if not found cache for master
b.cacheExtractor(w, info) b.cacheExtractor(w, info)
...@@ -323,7 +316,6 @@ func (b *AbstractShell) writeRestoreCacheScript(w ShellWriter, info common.Shell ...@@ -323,7 +316,6 @@ func (b *AbstractShell) writeRestoreCacheScript(w ShellWriter, info common.Shell
func (b *AbstractShell) writeDownloadArtifactsScript(w ShellWriter, info common.ShellScriptInfo) (err error) { func (b *AbstractShell) writeDownloadArtifactsScript(w ShellWriter, info common.ShellScriptInfo) (err error) {
b.writeExports(w, info) b.writeExports(w, info)
b.writeCdBuildDir(w, info) b.writeCdBuildDir(w, info)
b.writeTLSCAInfo(w, info.Build, "CI_SERVER_TLS_CA_FILE")
// Process all artifacts // Process all artifacts
b.downloadAllArtifacts(w, info) b.downloadAllArtifacts(w, info)
...@@ -503,7 +495,6 @@ func (b *AbstractShell) writeAfterScript(w ShellWriter, info common.ShellScriptI ...@@ -503,7 +495,6 @@ func (b *AbstractShell) writeAfterScript(w ShellWriter, info common.ShellScriptI
func (b *AbstractShell) writeArchiveCacheScript(w ShellWriter, info common.ShellScriptInfo) (err error) { func (b *AbstractShell) writeArchiveCacheScript(w ShellWriter, info common.ShellScriptInfo) (err error) {
b.writeExports(w, info) b.writeExports(w, info)
b.writeCdBuildDir(w, info) b.writeCdBuildDir(w, info)
b.writeTLSCAInfo(w, info.Build, "CI_SERVER_TLS_CA_FILE")
// Find cached files and archive them // Find cached files and archive them
b.cacheArchiver(w, info) b.cacheArchiver(w, info)
...@@ -513,7 +504,6 @@ func (b *AbstractShell) writeArchiveCacheScript(w ShellWriter, info common.Shell ...@@ -513,7 +504,6 @@ func (b *AbstractShell) writeArchiveCacheScript(w ShellWriter, info common.Shell
func (b *AbstractShell) writeUploadArtifactsScript(w ShellWriter, info common.ShellScriptInfo) (err error) { func (b *AbstractShell) writeUploadArtifactsScript(w ShellWriter, info common.ShellScriptInfo) (err error) {
b.writeExports(w, info) b.writeExports(w, info)
b.writeCdBuildDir(w, info) b.writeCdBuildDir(w, info)
b.writeTLSCAInfo(w, info.Build, "CI_SERVER_TLS_CA_FILE")
// Upload artifacts // Upload artifacts
b.uploadArtifacts(w, info) b.uploadArtifacts(w, info)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment