Commit c118757a authored by Matthew Stobbs's avatar Matthew Stobbs
Browse files

Add Authentication logic

- Able to auth in graphql
- TODO: Persist token for future requests
parent 14e3a76e
settings:
files:
outputs:
status: false
path: ""
name: .r.outputs.log
logs:
status: false
path: ""
name: .r.logs.log
errors:
status: false
path: ""
name: .r.errors.log
legacy:
force: false
interval: 0s
server:
status: true
open: true
port: 5002
host: localhost
schema:
- name: go-gallery
path: .
args:
- main.go
commands:
run:
status: true
watcher:
extensions:
- go
paths:
- /
ignore:
paths:
- .git
- .realize
- vendor
......@@ -14,16 +14,16 @@ import (
var (
selectAllImage = q.Prepare(
q.Select(
`id`,
`filename`,
`storagepath`,
`sha256`,
`size`,
`height`,
`width`,
`owner_id`,
`created_at`,
`updated_at`,
`images.id`,
`images.filename`,
`images.storagepath`,
`images.sha256`,
`images.size`,
`images.height`,
`images.width`,
`images.owner_id`,
`images.created_at`,
`images.updated_at`,
),
)
createImageQuery = q.QueryBuilder(
......@@ -41,6 +41,10 @@ var (
`updated_at`,
),
)
getAllImagesQuery = q.QueryBuilder(
selectAllImage,
q.From(`images`),
)
getImageByIDQuery = q.QueryBuilder(
selectAllImage,
q.From(`images`),
......@@ -97,6 +101,11 @@ func (d *Db) GetImageBySha256(sum string) (*models.Image, error) {
return &i[0], nil
}
// GetAllImages returns every image
func (d *Db) GetAllImages() ([]models.Image, error) {
return d.GetImageByQuery(getAllImagesQuery)
}
// GetImageByQuery retrieves models using a query
func (d *Db) GetImageByQuery(query string, arg ...interface{}) ([]models.Image, error) {
stmt, err := d.Prepare(query)
......@@ -106,10 +115,10 @@ func (d *Db) GetImageByQuery(query string, arg ...interface{}) ([]models.Image,
defer stmt.Close()
rows, err := stmt.Query(arg...)
defer rows.Close()
if err != nil {
return nil, NewUnknownError(err)
}
defer rows.Close()
var images []models.Image
var createdat, updatedat time.Time
......@@ -160,9 +169,6 @@ func (d *Db) checkIfImageExistsDB(ownerid uuid.UUID, sha256 string) error {
// CreateImage generates missing data and inserts an entry in the database
func (d *Db) CreateImage(i *models.Image, data []byte) error {
i.CreatedAt.Set(time.Now())
i.UpdatedAt = i.CreatedAt
stmt, err := d.Prepare(createImageQuery)
defer stmt.Close()
if err != nil {
......
package postgres
import (
"errors"
"log"
"time"
uuid "github.com/satori/go.uuid"
"gitlab.com/stobbsm/go-gallery/api/models"
q "gitlab.com/stobbsm/go-querybuilder"
)
// Queries needed
var (
checkShareExistsQuery = q.QueryBuilder(
q.Select(
`image_share.id`,
),
q.From(`image_share`),
q.WhereEq(`image_share.image_id`),
q.AndEq(`image_share.user_id`),
)
createImageShare = q.QueryBuilder(
q.Insert(
`image_share`,
`id`,
`image_id`,
`user_id`,
`shared`,
`created_at`,
`updated_at`,
),
)
getSharedImagesByUserID = q.QueryBuilder(
selectAllImage,
q.From(`images`),
q.InnerJoin(`image_share`),
q.On(`images.id`, `image_share.image_id`),
q.InnerJoin(`users`),
q.On(`users.id`, `image_share.user_id`),
q.WhereEq(`users.id`),
q.AndEq(`image_share.shared`),
)
getSharedUsersByImageID = q.QueryBuilder(
prepareSelectUser,
prepareFromUser,
q.InnerJoin(`users`),
q.On(`image_share.user_id`, `users.id`),
q.InnerJoin(`images`),
q.On(`image_share.image_id`, `images.id`),
q.WhereEq(`images.id`),
q.AndEq(`image_share.shared`),
)
)
// GetSharedImagesByUserID returns the list of images by the UserID
func (d *Db) GetSharedImagesByUserID(userid uuid.UUID) ([]models.Image, error) {
images, err := d.GetImageByQuery(getSharedImagesByUserID, userid, true)
if err != nil {
return nil, err
}
return images, nil
}
func (d *Db) getSharedUsersByImageID(imageid uuid.UUID) ([]models.User, error) {
users, err := d.GetUserByQuery(getSharedUsersByImageID, imageid, true)
if err != nil {
return nil, err
}
return users, nil
}
// checkShareExists checks for the existence of a share
func (d *Db) checkShareExists(imageid, userid uuid.UUID) error {
stmt, err := d.Prepare(checkShareExistsQuery)
defer stmt.Close()
if err != nil {
return NewUnknownError(err)
}
rows, err := stmt.Query(imageid, userid)
defer rows.Close()
if err != nil {
return NewUnknownError(err)
}
var s models.ImageShare
for rows.Next() {
log.Println(`Scanning share id`)
err = rows.Scan(
&s.ID,
)
if err != nil {
return NewModelNotFound(err)
}
}
if s.ID == uuid.Nil {
return NewModelNotFound(errors.New(`Share not found`))
}
return NewModelFound(errors.New(`Share Exists`))
}
// ShareImage inserts an image share
func (d *Db) ShareImage(imageshare *models.ImageShare) error {
if err := d.checkShareExists(imageshare.ImageID, imageshare.UserID); err != nil {
switch err.(type) {
case *ModelNotFound:
log.Println(err)
case *ModelFound:
return err
default:
return NewUnknownError(err)
}
}
stmt, err := d.Prepare(createImageShare)
defer stmt.Close()
if err != nil {
return NewUnknownError(err)
}
imageshare.ID = models.GenUUID()
imageshare.CreatedAt.Set(time.Now())
imageshare.UpdatedAt = imageshare.CreatedAt
_, err = stmt.Exec(
imageshare.ID,
imageshare.ImageID,
imageshare.UserID,
imageshare.Shared,
imageshare.CreatedAt.GetTime(),
imageshare.UpdatedAt.GetTime(),
)
if err != nil {
return NewUnknownError(err)
}
return nil
}
......@@ -102,3 +102,26 @@ func NewUnknownError(err error) *UnknownError {
func (u *UnknownError) Error() string {
return fmt.Sprintf(`Unknown Error: %s:%d : %s`, u.file, u.line, u.err)
}
// InvalidError error type
type InvalidError struct {
err error
file string
line int
}
// NewInvalidError generates a meaningful error message when
// a model isn't found.
func NewInvalidError(err error) *InvalidError {
_, file, line, _ := runtime.Caller(1)
return &InvalidError{
err: err,
file: file,
line: line,
}
}
// Error fullfills the interface for errors
func (m *InvalidError) Error() string {
return fmt.Sprintf(`Invalid Request: %s:%d : %s`, m.file, m.line, m.err)
}
......@@ -113,7 +113,7 @@ func (d *Db) GetRoleByQuery(query string, arg ...interface{}) ([]models.Role, er
&updatedat,
)
if err != nil {
return roles, NewUnknownError(err)
return roles, NewModelNotFound(err)
}
r.CreatedAt.Set(createdat)
r.UpdatedAt.Set(updatedat)
......
......@@ -21,6 +21,7 @@ var (
`users.name`,
`users.username`,
`users.email`,
`users.token`,
`users.created_at`,
`users.updated_at`,
),
......@@ -36,6 +37,11 @@ var (
prepareFromUser,
q.WhereEq(`users.username`),
)
getUserByEmailQuery = q.QueryBuilder(
prepareSelectUser,
prepareFromUser,
q.WhereEq(`users.email`),
)
getPasswordHash = q.QueryBuilder(
q.Select(`users.password`),
prepareFromUser,
......@@ -49,6 +55,7 @@ var (
`username`,
`password`,
`email`,
`token`,
`created_at`,
`updated_at`,
),
......@@ -68,6 +75,11 @@ var (
q.WhereEq(`user_id`),
q.AndEq(`role_id`),
)
updateUserTokenQuery = q.QueryBuilder(
q.Update(`users`),
q.Set(`token`),
q.WhereEq(`id`),
)
)
// GetUserByID is used by the database resolver to find users
......@@ -76,6 +88,9 @@ func (d *Db) GetUserByID(id uuid.UUID) (*models.User, error) {
if err != nil {
return nil, err
}
if len(u) == 0 {
return nil, NewModelNotFound(errors.New(`User not found`))
}
return &u[0], err
}
......@@ -95,9 +110,24 @@ func (d *Db) GetUsersByID(ids ...uuid.UUID) ([]models.User, error) {
// GetUserByUsername returns a user model containing the username
func (d *Db) GetUserByUsername(username string) (*models.User, error) {
u, err := d.GetUserByQuery(getUserByNameQuery, username)
if err != nil {
return nil, err
}
if len(u) == 0 {
return nil, NewModelNotFound(errors.New(`User not found`))
}
return &u[0], err
}
// GetUserByEmail returns a user model via the email address
func (d *Db) GetUserByEmail(email string) (*models.User, error) {
u, err := d.GetUserByQuery(getUserByEmailQuery, email)
if err != nil {
return nil, err
}
if len(u) == 0 {
return nil, NewModelNotFound(errors.New(`User not found`))
}
return &u[0], err
}
......@@ -124,6 +154,7 @@ func (d *Db) GetUserByQuery(query string, args ...interface{}) ([]models.User, e
&u.Name,
&u.Username,
&u.Email,
&u.Token,
&createdat,
&updatedat,
)
......@@ -169,20 +200,21 @@ func (d *Db) VerifyToken(id, token uuid.UUID) bool {
func (d *Db) Auth(user *models.User, password string) error {
stmt, err := d.Prepare(getPasswordHash)
if err != nil {
return fmt.Errorf("Auth [id:%s]: %s", user.ID.String(), err)
return NewUnknownError(err)
}
defer stmt.Close()
rows, err := stmt.Query(user.ID)
if err != nil {
return fmt.Errorf("Auth [id:%s]: %s", user.ID.String(), err)
return NewUnknownError(err)
}
defer rows.Close()
var phash string
for rows.Next() {
err = rows.Scan(&phash)
}
bcrypt.CompareHashAndPassword([]byte(phash), []byte(password))
return nil
return bcrypt.CompareHashAndPassword([]byte(phash), []byte(password))
}
func (d *Db) checkIfUserExists(username string) error {
......@@ -198,12 +230,10 @@ func (d *Db) checkIfUserExists(username string) error {
// CreateUser a new user in the given database
func (d *Db) CreateUser(u *models.User) error {
log.Println("Creating user:", u.Username)
if u.Password == "" {
return fmt.Errorf("createNewUser [u.Password]: Password not set")
return NewInvalidError(errors.New(`Password not set`))
}
log.Println("Checking if user already exists")
if err := d.checkIfUserExists(u.Username); err != nil {
switch err := err.(type) {
case *ModelNotFound:
......@@ -215,35 +245,32 @@ func (d *Db) CreateUser(u *models.User) error {
}
}
log.Println("Hashing password")
phash, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
u.Password = string(phash)
log.Println(`Inserting user:`, u.Username)
stmt, err := d.Prepare(insertUserQuery)
defer stmt.Close()
if err != nil {
return fmt.Errorf("CreateUser [d.Prepare]: %s", err)
return NewUnknownError(err)
}
u.ID = models.GenUUID()
u.CreatedAt.Set(time.Now())
u.UpdatedAt = u.CreatedAt
if err := u.GenToken(); err != nil {
return fmt.Errorf("CreateUser [u.GenToken]: %s", err)
}
u.GenToken()
_, err = stmt.Exec(
u.ID,
u.Name,
u.Username,
u.Password,
u.Email,
u.Token,
u.CreatedAt.GetTime(),
u.UpdatedAt.GetTime(),
)
if err != nil {
return fmt.Errorf(`CreateUser [stmt.Exec]: %s`, err)
return NewUnknownError(err)
}
return nil
}
......@@ -314,3 +341,20 @@ func (d *Db) AttachUserRole(u *models.User, r *models.Role) error {
}
return nil
}
// UpdateUserToken updates a users token
// TODO: Factor using a standard Update method
func (d *Db) UpdateUserToken(uid uuid.UUID, newtoken string) (*models.User, error) {
stmt, err := d.Prepare(updateUserTokenQuery)
if err != nil {
return nil, NewUnknownError(err)
}
defer stmt.Close()
log.Printf(`Uid: '%s' : NewToken '%s'`, uid.String(), newtoken)
_, err = stmt.Exec(newtoken, uid)
if err != nil {
return nil, NewUnknownError(err)
}
return d.GetUserByID(uid)
}
......@@ -4,15 +4,21 @@ import (
"github.com/graphql-go/graphql"
)
// AuthQuery defines an Auth query
func AuthQuery() *graphql.Field {
// LoginQuery defines the login query
func LoginQuery() *graphql.Field {
return &graphql.Field{
Type: UserGraphqlType,
Args: graphql.FieldConfigArgument{
`token`: &graphql.ArgumentConfig{
`username`: &graphql.ArgumentConfig{
Type: graphql.String,
},
`email`: &graphql.ArgumentConfig{
Type: graphql.String,
},
`password`: &graphql.ArgumentConfig{
Type: graphql.NewNonNull(graphql.String),
},
},
Resolve: resolver.AuthResolver,
Resolve: resolver.LoginResolver,
}
}
......@@ -25,28 +25,37 @@ func ExecuteQuery(query string, schema graphql.Schema) *graphql.Result {
return result
}
// Root query for graphql server
// Root schema for graphql server
type Root struct {
Query *graphql.Object
Query *graphql.Object
Mutation *graphql.Object
}
// NewRoot creates a new root query for graphql
func NewRoot(db *postgres.Db) *Root {
resolver = resolvers.New(db)
root := Root{
Query: graphql.NewObject(
graphql.ObjectConfig{
Name: `Query`,
Fields: graphql.Fields{
`user`: UserQuery(),
`auth`: AuthQuery(),
`login`: LoginQuery(),
`image`: ImageQuery(),
`images`: ImagesQuery(),
`groups`: GroupQuery(),
},
},
),
Mutation: graphql.NewObject(
graphql.ObjectConfig{
Name: `Mutation`,
Fields: graphql.Fields{
`newuser`: AddUserMutation(),
`newtoken`: UpdateUserToken(),
},
},
),
}
return &root
}
......@@ -27,10 +27,10 @@ func init() {
Type: graphql.NewList(types.CustomIDType),
},
`created_at`: &graphql.Field{
Type: graphql.DateTime,
Type: types.GraphqlDateType,
},
`updated_at`: &graphql.Field{
Type: graphql.DateTime,
Type: types.GraphqlDateType,
},
}
}),
......
......@@ -35,10 +35,10 @@ func init() {
Type: types.CustomIDType,
},
`created_at`: &graphql.Field{
Type: graphql.DateTime,
Type: types.GraphqlDateType,
},
`updated_at`: &graphql.Field{
Type: graphql.DateTime,
Type: types.GraphqlDateType,
},
}
}),
......@@ -69,6 +69,9 @@ func ImagesQuery() *graphql.Field {
`owner_id`: &graphql.ArgumentConfig{
Type: graphql.String,
},
`all`: &graphql.ArgumentConfig{
Type: graphql.Boolean,
},
},
Resolve: resolver.ImagesResolver,
}
......
......@@ -26,10 +26,10 @@ func init() {
Type: types.CustomIDType,
},
`created_at`: &graphql.Field{
Type: graphql.DateTime,
Type: types.GraphqlDateType,
},
`updated_at`: &graphql.Field{
Type: graphql.DateTime,
Type: types.GraphqlDateType,
},
}
}),
......
......@@ -26,10 +26,10 @@ func init() {
Type: graphql.NewList(UserGraphqlType),
},
`created_at`: &graphql.Field{
Type: graphql.DateTime,
Type: types.GraphqlDateType,
},
`updated_at`: &graphql.Field{
Type: graphql.DateTime,
Type: types.GraphqlDateType,
},
}
}),
......
......@@ -26,7 +26,7 @@ func init() {
Type: graphql.String,
},
`token`: &graphql.Field{
Type: graphql.String,
Type: types.TokenType,
},
`created_at`: &graphql.Field{
Type: types.GraphqlDateType,