Commit 6b228730 authored by Sybren A. Stüvel's avatar Sybren A. Stüvel

Added server signature verification

parent e61590f6
package bunqapi
import (
"crypto/rsa"
"io/ioutil"
"net/url"
"path/filepath"
......@@ -18,7 +19,8 @@ type Credentials struct {
InstallationToken string `yaml:"installationToken"`
ServerPublicKey string `yaml:"serverPublicKey"`
DeviceID int `yaml:"deviceID"`
serverPublicKey *rsa.PublicKey
DeviceID int `yaml:"deviceID"`
}
func LoadCredentials() Credentials {
......@@ -55,6 +57,10 @@ func LoadCredentials() Credentials {
}
logger.WithField("url", creds.apiURL.String()).Debug("set API URL based on API mode")
if creds.ServerPublicKey != "" {
creds.serverPublicKey = parsePublicRSAKeyString(creds.ServerPublicKey)
}
return creds
}
......
......@@ -102,7 +102,8 @@ func (c *Client) DoRequest(method, endpoint string, payload interface{}, respons
logger.WithError(err).Fatal("unable to perform HTTP call")
}
// TODO: verify server signature
c.VerifyResponse(resp)
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
logger.WithError(err).Fatal("unable to read HTTP body")
......
......@@ -30,6 +30,7 @@ const (
headerXBunqGeolocation = "X-Bunq-Geolocation"
headerXBunqLanguage = "X-Bunq-Language"
headerXBunqRegion = "X-Bunq-Region"
headerXBunqServerSignature = "X-Bunq-Server-Signature"
)
// Default header values
......
package bunqapi
import "github.com/sirupsen/logrus"
type deviceServerRequest struct {
// The description of the DeviceServer. This is only for your own reference when reading the DeviceServer again.
Description string `json:"description"`
......@@ -20,7 +18,7 @@ type deviceServerResponse struct {
func (c *Client) CheckDeviceServer(description string, permittedIPs []string) {
if c.creds.DeviceID != 0 {
logrus.WithField("deviceID", c.creds.DeviceID).Debug("device ID is known")
log.WithField("deviceID", c.creds.DeviceID).Debug("device ID is known")
return
}
c.PostDeviceServer(description, permittedIPs)
......@@ -42,6 +40,6 @@ func (c *Client) PostDeviceServer(description string, permittedIPs []string) {
MergeStructs(wrappedResponse.Response, &response)
c.creds.DeviceID = response.ID.ID
logrus.WithField("id", c.creds.DeviceID).Info("registered device device ID")
log.WithField("id", c.creds.DeviceID).Info("registered device device ID")
c.creds.Save()
}
......@@ -39,7 +39,7 @@ func main() {
bunqapi.MergeStructs(fromBunq.Response, &merged)
out, err := json.MarshalIndent(merged, "", " ")
if err != nil {
logrus.WithError(err).Fatal("unable to marshal result to JSON")
log.WithError(err).Fatal("unable to marshal result to JSON")
}
os.Stdout.Write(out)
os.Stdout.WriteString("\n")
......
......@@ -2,8 +2,6 @@ package bunqapi
import (
"fmt"
"github.com/sirupsen/logrus"
)
type wrappedMonetaryAccountBankResponse struct {
......@@ -54,7 +52,7 @@ func (c *Client) GetMonetaryAccountBankList() {
url := fmt.Sprintf("user/%d/monetary-account-bank", c.session.UserPerson.ID)
errResp := c.DoRequest("GET", url, nil, &wrappedResponse)
if errResp != nil {
logrus.WithFields(errResp.LogFields()).Fatal("error performing session creation request")
log.WithFields(errResp.LogFields()).Fatal("error performing session creation request")
}
}
......@@ -63,11 +61,11 @@ func (c *Client) CreateMonetaryAccountBank(account MonetaryAccountBank) {
url := fmt.Sprintf("user/%d/monetary-account-bank", c.session.UserPerson.ID)
errResp := c.DoRequest("POST", url, &account, &wrappedResponse)
if errResp != nil {
logrus.WithFields(errResp.LogFields()).Fatal("error performing session creation request")
log.WithFields(errResp.LogFields()).Fatal("error performing session creation request")
}
response := monetaryAccountBankCreateResponse{}
MergeStructs(wrappedResponse.Response, &response)
logrus.WithField("bankAccountID", response.ID.ID).Info("bank account created")
log.WithField("bankAccountID", response.ID.ID).Info("bank account created")
}
......@@ -45,6 +45,21 @@ func readPEM(filename, pemType string) ([]byte, error) {
}
}
func parsePEM(pemString, pemType string) ([]byte, error) {
pemdata := []byte(pemString)
var block *pem.Block
for {
block, pemdata = pem.Decode(pemdata)
if block == nil {
log.WithField("expectedHeader", pemType).Panic("unable to find expected block of PEM data")
}
if block.Type == pemType {
return block.Bytes, nil
}
}
}
func generateNewPrivateKey() *rsa.PrivateKey {
privkey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
......@@ -82,3 +97,24 @@ func ensureOwnRSAKey() *rsa.PrivateKey {
return privkey
}
func parsePublicRSAKeyString(pemString string) *rsa.PublicKey {
keyBytes, err := parsePEM(pemString, "PUBLIC KEY")
if err != nil {
log.WithError(err).Error("unable to parse public key, assuming it's not there")
return nil
}
pubkey, err := x509.ParsePKIXPublicKey(keyBytes)
if err != nil {
log.WithError(err).Fatal("unable to parse public RSA key")
}
rsaPubkey, ok := pubkey.(*rsa.PublicKey)
if !ok {
log.Fatal("public key is not an RSA key")
}
log.WithField("keySizeBits", rsaPubkey.Size()*8).Debug("loaded public RSA key")
return rsaPubkey
}
......@@ -45,7 +45,7 @@ func (c *Client) SessionStart() {
c.session.UserCompany = response.UserCompany
c.session.UserPerson = response.UserPerson
logrus.WithFields(logrus.Fields{
log.WithFields(logrus.Fields{
"firstName": c.session.UserPerson.FirstName,
"lastName": c.session.UserPerson.LastName,
"session": c.session.Token,
......
......@@ -10,6 +10,7 @@ import (
"fmt"
"io/ioutil"
"net/http"
"net/textproto"
"sort"
"strings"
......@@ -22,9 +23,9 @@ var headersToSign = map[string]bool{
headerUserAgent: true,
}
func sortedHeaderKeys(r *http.Request) []string {
func sortedHeaderKeys(header http.Header) []string {
headerKeys := []string{}
for key := range r.Header {
for key := range header {
headerKeys = append(headerKeys, key)
}
sort.Strings(headerKeys)
......@@ -43,7 +44,7 @@ func (c *Client) SignRequest(r *http.Request) error {
}
hash(fmt.Sprintf("%s %s\n", r.Method, r.URL.Path))
headerKeys := sortedHeaderKeys(r)
headerKeys := sortedHeaderKeys(r.Header)
for _, key := range headerKeys {
if !headersToSign[key] && !strings.HasPrefix(strings.ToLower(key), "x-bunq-") {
continue
......@@ -85,3 +86,60 @@ func (c *Client) SignRequest(r *http.Request) error {
r.Header.Set(headerXBunqClientSignature, encodedSig)
return nil
}
func (c *Client) VerifyResponse(r *http.Response) error {
// TODO: merge common code between this function and SignRequest().
serverSignatureB64 := r.Header.Get(headerXBunqServerSignature)
if serverSignatureB64 == "" {
log.Panic("this response has not been signed")
}
hasher := sha256.New()
hash := func(value string) {
log.WithField("value", value).Debug("writing to hasher")
hasher.Write([]byte(value))
}
hash(fmt.Sprintf("%d\n", r.StatusCode))
headerKeys := sortedHeaderKeys(r.Header)
serverSigLowerKey := strings.ToLower(headerXBunqServerSignature)
for _, key := range headerKeys {
lowerKey := strings.ToLower(key)
if !strings.HasPrefix(lowerKey, "x-bunq-") || lowerKey == serverSigLowerKey {
continue
}
for _, value := range r.Header[key] {
hash(fmt.Sprintf("%s: %s\n", textproto.CanonicalMIMEHeaderKey(key), value))
}
}
hash("\n")
bodyBytes, err := ioutil.ReadAll(r.Body)
if err != nil {
log.WithError(err).Panic("unable to read response body for signing")
}
hasher.Write(bodyBytes)
log.WithField("value", string(bodyBytes)).Debug("writing body to hasher")
r.Body = ioutil.NopCloser(bytes.NewReader(bodyBytes))
sum := hasher.Sum(nil)
// Verify the signature just to be sure ;-)
serverSignature, err := base64.StdEncoding.DecodeString(serverSignatureB64)
if err != nil {
log.WithError(err).Fatal("unable to base64-decode the server signature")
}
if err := rsa.VerifyPKCS1v15(c.creds.serverPublicKey, crypto.SHA256, sum, serverSignature); err != nil {
log.WithError(err).Fatal("unable to verify server signature with public RSA key")
}
encodedSig := base64.StdEncoding.EncodeToString(serverSignature)
log.WithFields(logrus.Fields{
"shaSum": fmt.Sprintf("%x", sum),
"signature": encodedSig,
}).Debug("verified server response signature")
return nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment