// Package invite sends email invitations for users to join an account and
// processes the acceptance of those invitations.
package invite

import (
	"context"
	"fmt"
	"strings"
	"time"

	"geeks-accelerator/oss/saas-starter-kit/internal/account"
	"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
	"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
	"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
	"geeks-accelerator/oss/saas-starter-kit/internal/user"
	"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
	"github.com/jmoiron/sqlx"
	"github.com/pkg/errors"
	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)

var (
	// ErrInviteExpired occurs when the invite hash exceeds its expiration.
	ErrInviteExpired = errors.New("Invite expired")

	// ErrNoPendingInvite occurs when the user_accounts entry referenced by an invite
	// does not have the status of invited (and is not already active).
	ErrNoPendingInvite = errors.New("No pending invite.")

	// ErrUserAccountActive occurs when the user already has an active user_account entry.
	ErrUserAccountActive = errors.New("User already active.")
)

31 32 33
// SendUserInvites sends emails to the users inviting them to join an account.
func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, resetUrl func(string) string, notify notify.Email, req SendUserInvitesRequest, secretKey string, now time.Time) ([]string, error) {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.SendUserInvites")
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
	defer span.Finish()

	v := webcontext.Validator()

	// Validate the request.
	err := v.StructCtx(ctx, req)
	if err != nil {
		return nil, err
	}

	// Ensure the claims can modify the account specified in the request.
	err = user_account.CanModifyAccount(ctx, claims, dbConn, req.AccountID)
	if err != nil {
		return nil, err
	}

	// Find all the users by email address.
	emailUserIDs := make(map[string]string)
	{
		// Find all users without passing in claims to search all users.
		users, err := user.Find(ctx, auth.Claims{}, dbConn, user.UserFindRequest{
Lee Brown's avatar
Lee Brown committed
55 56
			Where: fmt.Sprintf("email in ('%s')",
				strings.Join(req.Emails, "','")),
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
		})
		if err != nil {
			return nil, err
		}

		for _, u := range users {
			emailUserIDs[u.Email] = u.ID
		}
	}

	// Find users that are already active for this account.
	activelUserIDs := make(map[string]bool)
	{
		var args []string
		for _, userID := range emailUserIDs {
			args = append(args, userID)
		}

		userAccs, err := user_account.Find(ctx, claims, dbConn, user_account.UserAccountFindRequest{
Lee Brown's avatar
Lee Brown committed
76 77 78
			Where: fmt.Sprintf("user_id in ('%s') and status = '%s'",
				strings.Join(args, "','"),
				user_account.UserAccountStatus_Active.String()),
79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
		})
		if err != nil {
			return nil, err
		}

		for _, userAcc := range userAccs {
			activelUserIDs[userAcc.UserID] = true
		}
	}

	// Always store the time as UTC.
	now = now.UTC()

	// Postgres truncates times to milliseconds when storing. We and do the same
	// here so the value we return is consistent with what we store.
	now = now.Truncate(time.Millisecond)

	// Create any users that don't already exist.
	for _, email := range req.Emails {
		if uId, ok := emailUserIDs[email]; ok && uId != "" {
			continue
		}

		u, err := user.CreateInvite(ctx, claims, dbConn, user.UserCreateInviteRequest{
			Email: email,
		}, now)
		if err != nil {
			return nil, err
		}

		emailUserIDs[email] = u.ID
	}

	// Loop through all the existing users who either do not have an user_account record or
	// have an existing record, but the status is disabled.
	for _, userID := range emailUserIDs {
		// User already is active, skip.
		if activelUserIDs[userID] {
			continue
		}

		status := user_account.UserAccountStatus_Invited
		_, err = user_account.Create(ctx, claims, dbConn, user_account.UserAccountCreateRequest{
			UserID:    userID,
			AccountID: req.AccountID,
			Roles:     req.Roles,
			Status:    &status,
		}, now)
		if err != nil {
			return nil, err
		}
	}

	if req.TTL.Seconds() == 0 {
		req.TTL = time.Minute * 90
	}

136
	fromUser, err := user.ReadByID(ctx, claims, dbConn, req.UserID)
137 138 139 140
	if err != nil {
		return nil, err
	}

141
	account, err := account.ReadByID(ctx, claims, dbConn, req.AccountID)
142 143 144 145 146 147 148 149 150 151 152 153
	if err != nil {
		return nil, err
	}

	// Load the current IP makings the request.
	var requestIp string
	if vals, _ := webcontext.ContextValues(ctx); vals != nil {
		requestIp = vals.RequestIP
	}

	var inviteHashes []string
	for email, userID := range emailUserIDs {
Lee Brown's avatar
Lee Brown committed
154
		hash, err := NewInviteHash(ctx, secretKey, userID, req.AccountID, requestIp, req.TTL, now)
155
		if err != nil {
Lee Brown's avatar
Lee Brown committed
156
			return nil, err
157 158 159 160 161
		}

		data := map[string]interface{}{
			"FromUser": fromUser.Response(ctx),
			"Account":  account.Response(ctx),
Lee Brown's avatar
Lee Brown committed
162
			"Url":      resetUrl(hash),
163 164 165 166 167 168 169 170 171 172 173
			"Minutes":  req.TTL.Minutes(),
		}

		subject := fmt.Sprintf("%s %s has invited you to %s", fromUser.FirstName, fromUser.LastName, account.Name)

		err = notify.Send(ctx, email, subject, "user_invite", data)
		if err != nil {
			err = errors.WithMessagef(err, "Send invite to %s failed.", email)
			return nil, err
		}

Lee Brown's avatar
Lee Brown committed
174
		inviteHashes = append(inviteHashes, hash)
175 176 177 178 179
	}

	return inviteHashes, nil
}

180
// AcceptInvite updates the user using the provided invite hash.
181
func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) {
182
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInvite")
183 184 185 186 187 188 189
	defer span.Finish()

	v := webcontext.Validator()

	// Validate the request.
	err := v.StructCtx(ctx, req)
	if err != nil {
190
		return nil, err
191 192
	}

193
	hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now)
194
	if err != nil {
195
		return nil, err
196 197
	}

198 199
	u, err := user.Read(ctx, auth.Claims{}, dbConn,
		user.UserReadRequest{ID: hash.UserID, IncludeArchived: true})
200
	if err != nil {
201
		return nil, err
202 203 204
	}

	if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() {
205
		err = user.Restore(ctx, auth.Claims{}, dbConn, user.UserRestoreRequest{ID: hash.UserID}, now)
206
		if err != nil {
207
			return nil, err
Lee Brown's avatar
Lee Brown committed
208 209 210 211 212 213 214 215
		}
	}

	usrAcc, err := user_account.Read(ctx, auth.Claims{}, dbConn, user_account.UserAccountReadRequest{
		UserID:    hash.UserID,
		AccountID: hash.AccountID,
	})
	if err != nil {
216
		return nil, err
Lee Brown's avatar
Lee Brown committed
217 218 219 220 221 222
	}

	// Ensure the entry has the status of invited.
	if usrAcc.Status != user_account.UserAccountStatus_Invited {
		// If the entry is already active
		if usrAcc.Status == user_account.UserAccountStatus_Active {
223
			return usrAcc, errors.WithStack(ErrUserAccountActive)
224
		}
225
		return usrAcc, errors.WithStack(ErrNoPendingInvite)
Lee Brown's avatar
Lee Brown committed
226 227
	}

228 229
	// If the user already has a password set, then just update the user_account entry to status of active.
	// The user will need to login and should not be auto-authenticated.
Lee Brown's avatar
Lee Brown committed
230
	if len(u.PasswordHash) > 0 {
231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291
		usrAcc.Status = user_account.UserAccountStatus_Active

		err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{
			UserID:    usrAcc.UserID,
			AccountID: usrAcc.AccountID,
			Status:    &usrAcc.Status,
		}, now)
		if err != nil {
			return nil, err
		}
	}

	return usrAcc, nil
}

// AcceptInviteUser updates the user using the provided invite hash.
func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUserRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInviteUser")
	defer span.Finish()

	v := webcontext.Validator()

	// Validate the request.
	err := v.StructCtx(ctx, req)
	if err != nil {
		return nil, err
	}

	hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now)
	if err != nil {
		return nil, err
	}

	u, err := user.Read(ctx, auth.Claims{}, dbConn,
		user.UserReadRequest{ID: hash.UserID, IncludeArchived: true})
	if err != nil {
		return nil, err
	}

	if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() {
		err = user.Restore(ctx, auth.Claims{}, dbConn, user.UserRestoreRequest{ID: hash.UserID}, now)
		if err != nil {
			return nil, err
		}
	}

	usrAcc, err := user_account.Read(ctx, auth.Claims{}, dbConn, user_account.UserAccountReadRequest{
		UserID:    hash.UserID,
		AccountID: hash.AccountID,
	})
	if err != nil {
		return nil, err
	}

	// Ensure the entry has the status of invited.
	if usrAcc.Status != user_account.UserAccountStatus_Invited {
		// If the entry is already active
		if usrAcc.Status == user_account.UserAccountStatus_Active {
			return usrAcc, errors.WithStack(ErrUserAccountActive)
		}
		return nil, errors.WithStack(ErrNoPendingInvite)
292 293
	}

294 295
	// These three calls, user.Update,  user.UpdatePassword, and user_account.Update
	// should probably be in a transaction!
296 297
	err = user.Update(ctx, auth.Claims{}, dbConn, user.UserUpdateRequest{
		ID:        hash.UserID,
Lee Brown's avatar
Lee Brown committed
298
		Email:     &req.Email,
299 300 301 302 303
		FirstName: &req.FirstName,
		LastName:  &req.LastName,
		Timezone:  req.Timezone,
	}, now)
	if err != nil {
304
		return nil, err
305 306 307 308 309 310 311 312
	}

	err = user.UpdatePassword(ctx, auth.Claims{}, dbConn, user.UserUpdatePasswordRequest{
		ID:              hash.UserID,
		Password:        req.Password,
		PasswordConfirm: req.PasswordConfirm,
	}, now)
	if err != nil {
313
		return nil, err
Lee Brown's avatar
Lee Brown committed
314 315
	}

316
	usrAcc.Status = user_account.UserAccountStatus_Active
Lee Brown's avatar
Lee Brown committed
317 318 319
	err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{
		UserID:    usrAcc.UserID,
		AccountID: usrAcc.AccountID,
320
		Status:    &usrAcc.Status,
Lee Brown's avatar
Lee Brown committed
321 322
	}, now)
	if err != nil {
323
		return nil, err
324 325
	}

326
	return usrAcc, nil
327
}