Commit a5260fe5 authored by Daniel P. Berrange's avatar Daniel P. Berrange

Rewrite connection auth method to be more flexible

The NewConnectWithAuth method is hardcoded to only
support 2 credential types. Rewrite it to expose the
full callback facility from the C API
Signed-off-by: default avatarDaniel P. Berrange <berrange@redhat.com>
parent d84c1034
......@@ -297,30 +297,73 @@ func NewConnect(uri string) (*Connect, error) {
return &Connect{ptr: ptr}, nil
}
func NewConnectWithAuth(uri string, username string, password string) (*Connect, error) {
type ConnectCredential struct {
Type ConnectCredentialType
Prompt string
Challenge string
DefResult string
Result string
ResultLen int
}
type ConnectAuthCallback func(creds []*ConnectCredential)
type ConnectAuth struct {
CredType []ConnectCredentialType
Callback ConnectAuthCallback
}
//export connectAuthCallback
func connectAuthCallback(ccredlist C.virConnectCredentialPtr, ncred C.uint, callbackID C.int) C.int {
cred := make([]*ConnectCredential, int(ncred))
for i := 0; i < int(ncred); i++ {
ccred := (C.virConnectCredentialPtr)(unsafe.Pointer((uintptr)(unsafe.Pointer(ccredlist)) + (unsafe.Sizeof(*ccredlist) * uintptr(i))))
cred[i] = &ConnectCredential{
Type: ConnectCredentialType(ccred._type),
Prompt: C.GoString(ccred.prompt),
Challenge: C.GoString(ccred.challenge),
DefResult: C.GoString(ccred.defresult),
ResultLen: -1,
}
}
callbackEntry := getCallbackId(int(callbackID))
callback, ok := callbackEntry.(ConnectAuthCallback)
if !ok {
panic("Unexpected callback type")
}
callback(cred)
for i := 0; i < int(ncred); i++ {
ccred := (C.virConnectCredentialPtr)(unsafe.Pointer((uintptr)(unsafe.Pointer(ccredlist)) + (unsafe.Sizeof(*ccredlist) * uintptr(i))))
if cred[i].ResultLen >= 0 {
ccred.result = C.CString(cred[i].Result)
ccred.resultlen = C.uint(cred[i].ResultLen)
}
}
return 0
}
func NewConnectWithAuth(uri string, auth *ConnectAuth, flags uint32) (*Connect, error) {
var cUri *C.char
authMechs := C.authMechs()
defer C.free(unsafe.Pointer(authMechs))
cUsername := C.CString(username)
defer C.free(unsafe.Pointer(cUsername))
cPassword := C.CString(password)
defer C.free(unsafe.Pointer(cPassword))
cbData := C.authData(cUsername, C.uint(len(username)), cPassword, C.uint(len(password)))
defer C.free(unsafe.Pointer(cbData))
ccredtype := make([]C.int, len(auth.CredType))
auth := C.virConnectAuth{
credtype: authMechs,
ncredtype: C.uint(2),
cb: C.virConnectAuthCallbackPtr(unsafe.Pointer(C.authCb)),
cbdata: unsafe.Pointer(cbData),
for i := 0; i < len(auth.CredType); i++ {
ccredtype[i] = C.int(auth.CredType[i])
}
if uri != "" {
cUri = C.CString(uri)
defer C.free(unsafe.Pointer(cUri))
}
ptr := C.virConnectOpenAuth(cUri, (*C.struct__virConnectAuth)(unsafe.Pointer(&auth)), C.uint(0))
callbackID := registerCallbackId(auth.Callback)
ptr := C.virConnectOpenAuthWrap(cUri, &ccredtype[0], C.uint(len(auth.CredType)), C.int(callbackID), C.uint(flags))
freeCallbackId(callbackID)
if ptr == nil {
return nil, GetLastError()
}
......
......@@ -4,8 +4,6 @@ package libvirt
#cgo pkg-config: libvirt
#include <libvirt/libvirt.h>
#include <libvirt/virterror.h>
#include <stdlib.h>
#include <string.h>
#include "connect_cfuncs.h"
#include "callbacks_cfuncs.h"
......@@ -15,48 +13,32 @@ void closeCallback_cgo(virConnectPtr conn, int reason, void *opaque)
closeCallback(conn, reason, (long)opaque);
}
int authCb(virConnectCredentialPtr cred, unsigned int ncred, void *cbdata)
int virConnectRegisterCloseCallback_cgo(virConnectPtr c, virConnectCloseFunc cb, long goCallbackId)
{
int i;
auth_cb_data *data = (auth_cb_data*)cbdata;
for (i = 0; i < ncred; i++) {
if (cred[i].type == VIR_CRED_AUTHNAME) {
cred[i].result = strndup(data->username, data->username_len);
if (cred[i].result == NULL)
return -1;
cred[i].resultlen = strlen(cred[i].result);
}
else if (cred[i].type == VIR_CRED_PASSPHRASE) {
cred[i].result = strndup(data->passphrase, data->passphrase_len);
if (cred[i].result == NULL)
return -1;
cred[i].resultlen = strlen(cred[i].result);
}
}
return 0;
void *id = (void*)goCallbackId;
return virConnectRegisterCloseCallback(c, cb, id, freeGoCallback_cgo);
}
auth_cb_data* authData(char* username, uint username_len, char* passphrase, uint passphrase_len) {
auth_cb_data * data = malloc(sizeof(auth_cb_data));
data->username = username;
data->username_len = username_len;
data->passphrase = passphrase;
data->passphrase_len = passphrase_len;
return data;
}
#include <stdio.h>
int* authMechs() {
int* authMechs = malloc(2*sizeof(VIR_CRED_AUTHNAME));
authMechs[0] = VIR_CRED_AUTHNAME;
authMechs[1] = VIR_CRED_PASSPHRASE;
return authMechs;
extern int connectAuthCallback(virConnectCredentialPtr, unsigned int, int);
int connectAuthCallback_cgo(virConnectCredentialPtr cred, unsigned int ncred, void *cbdata)
{
int *callbackID = cbdata;
return connectAuthCallback(cred, ncred, *callbackID);
}
int virConnectRegisterCloseCallback_cgo(virConnectPtr c, virConnectCloseFunc cb, long goCallbackId)
virConnectPtr virConnectOpenAuthWrap(const char *name, int *credtype, uint ncredtype, int callbackID, unsigned int flags)
{
void *id = (void*)goCallbackId;
return virConnectRegisterCloseCallback(c, cb, id, freeGoCallback_cgo);
virConnectAuth auth = {
.credtype = credtype,
.ncredtype = ncredtype,
.cb = connectAuthCallback_cgo,
.cbdata = &callbackID,
};
return virConnectOpenAuth(name, &auth, flags);
}
*/
......
......@@ -3,15 +3,6 @@
void closeCallback_cgo(virConnectPtr conn, int reason, void *opaque);
int virConnectRegisterCloseCallback_cgo(virConnectPtr c, virConnectCloseFunc cb, long goCallbackId);
typedef struct auth_cb_data {
char* username;
uint username_len;
char* passphrase;
uint passphrase_len;
} auth_cb_data;
int* authMechs();
int authCb(virConnectCredentialPtr cred, unsigned int ncred, void *cbdata);
auth_cb_data* authData(char* username, uint username_len, char* passphrase, uint passphrase_len);
virConnectPtr virConnectOpenAuthWrap(const char *name, int *credtype, uint ncredtype, int callbackID, unsigned int flags);
#endif /* GO_LIBVIRT_H */
......@@ -139,7 +139,24 @@ func TestSetKeepalive(t *testing.T) {
}
func TestConnectionWithAuth(t *testing.T) {
conn, err := NewConnectWithAuth("test+tcp://127.0.0.1/default", "user", "pass")
callback := func(creds []*ConnectCredential) {
for _, cred := range creds {
if cred.Type == CRED_AUTHNAME {
cred.Result = "user"
cred.ResultLen = len(cred.Result)
} else if cred.Type == CRED_PASSPHRASE {
cred.Result = "pass"
cred.ResultLen = len(cred.Result)
}
}
}
auth := &ConnectAuth{
CredType: []ConnectCredentialType{
CRED_AUTHNAME, CRED_PASSPHRASE,
},
Callback: callback,
}
conn, err := NewConnectWithAuth("test+tcp://127.0.0.1/default", auth, 0)
if err != nil {
t.Error(err)
return
......@@ -155,7 +172,24 @@ func TestConnectionWithAuth(t *testing.T) {
}
func TestConnectionWithWrongCredentials(t *testing.T) {
conn, err := NewConnectWithAuth("test+tcp://127.0.0.1/default", "user", "wrongpass")
callback := func(creds []*ConnectCredential) {
for _, cred := range creds {
if cred.Type == CRED_AUTHNAME {
cred.Result = "user"
cred.ResultLen = len(cred.Result)
} else if cred.Type == CRED_PASSPHRASE {
cred.Result = "wrongpass"
cred.ResultLen = len(cred.Result)
}
}
}
auth := &ConnectAuth{
CredType: []ConnectCredentialType{
CRED_AUTHNAME, CRED_PASSPHRASE,
},
Callback: callback,
}
conn, err := NewConnectWithAuth("test+tcp://127.0.0.1/default", auth, 0)
if err == nil {
conn.CloseConnection()
t.Error(err)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment