119 lines
2.8 KiB
Go
119 lines
2.8 KiB
Go
package models
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// resetTokenBytes is the byte count that resetToken passes to RandString
// when generating a new password-reset token.
const resetTokenBytes int = 32
|
|
|
|
func resetToken() (string, error) {
|
|
return RandString(resetTokenBytes)
|
|
}
|
|
// hashToken returns the SHA-256 digest of token, encoded as URL-safe base64
// with padding (always 44 characters). Reset tokens are persisted and looked
// up only in this hashed form, never as plaintext.
func hashToken(token string) string {
	hasher := sha256.New()
	// hash.Hash.Write never returns an error.
	hasher.Write([]byte(token))
	return base64.URLEncoding.EncodeToString(hasher.Sum(nil))
}
|
|
|
|
type PasswordReset struct {
|
|
ID int
|
|
UserID int
|
|
// Token is only set when a PasswordReset is being created.
|
|
Token string
|
|
TokenHash string
|
|
ExpiresAt time.Time
|
|
}
|
|
|
|
// PasswordResetService issues and redeems password-reset tokens, which are
// kept in the password_resets table.
type PasswordResetService struct {
	// DB is the database containing the users and password_resets tables.
	DB *sql.DB

	// BytesPerToken presumably sets how many random bytes go into each
	// generated token. NOTE(review): confirm how it is consumed.
	BytesPerToken int

	// Duration is how long a newly created reset token stays valid.
	// Create uses one hour when Duration is zero.
	Duration time.Duration
}
|
|
|
|
func (service *PasswordResetService) delete(id int) error {
|
|
_, err := service.DB.Exec(`
|
|
DELETE FROM password_resets
|
|
WHERE id = $1;`, id)
|
|
if err != nil {
|
|
return fmt.Errorf("delete: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Create starts a password reset for the user whose email matches email.
//
// Steps:
//   - The email is lower-cased before the lookup.
//   - A random token is generated by RandString, using service.BytesPerToken
//     as the size but never less than resetTokenBytes.
//   - Only the token's hash (hashToken) is written to password_resets, with
//     an expiry of service.Duration from now (one hour when Duration is zero).
//   - If the user already has a row in password_resets, that row's hash and
//     expiry are overwritten, so any earlier token stops matching.
//
// The returned PasswordReset is the only place the plaintext Token is
// available; delivering it to the user is up to the caller. Every error,
// including sql.ErrNoRows when no user has this email, is wrapped with a
// "create: " prefix.
func (service *PasswordResetService) Create(email string) (*PasswordReset, error) {
	// Normalize case so the lookup matches however the address was typed.
	// NOTE(review): assumes users.email is stored lower-cased.
	email = strings.ToLower(email)

	var userID int
	err := service.DB.QueryRow(`
		SELECT id FROM users WHERE email = $1;`, email).Scan(&userID)
	if err != nil {
		return nil, fmt.Errorf("create: %w", err)
	}

	// Honor the configured token size. A zero-value (or too small)
	// BytesPerToken falls back to resetTokenBytes.
	bytesPerToken := service.BytesPerToken
	if bytesPerToken < resetTokenBytes {
		bytesPerToken = resetTokenBytes
	}
	token, err := RandString(bytesPerToken)
	if err != nil {
		return nil, fmt.Errorf("create: %w", err)
	}

	duration := service.Duration
	if duration == 0 {
		duration = time.Hour
	}

	pwReset := PasswordReset{
		UserID:    userID,
		Token:     token,
		TokenHash: hashToken(token),
		ExpiresAt: time.Now().Add(duration),
	}

	// Upsert on user_id: issuing a new reset replaces the previous token hash
	// and expiry, so only the newest token can be redeemed.
	err = service.DB.QueryRow(`
		INSERT INTO password_resets (user_id, token_hash, expires_at)
		VALUES ($1, $2, $3)
		ON CONFLICT (user_id) DO UPDATE
		SET token_hash = $2, expires_at = $3
		RETURNING id;`, pwReset.UserID, pwReset.TokenHash, pwReset.ExpiresAt).Scan(&pwReset.ID)
	if err != nil {
		return nil, fmt.Errorf("create: %w", err)
	}
	return &pwReset, nil
}
|
|
|
|
// Consume redeems a password-reset token and returns the user it was issued
// for.
//
// The token is looked up by its hash (hashToken), and the matching
// password_resets row is deleted so the token cannot be used again. An error
// is returned, always prefixed with "consume: ", in any of these cases:
//   - the token is unknown (wrapping sql.ErrNoRows);
//   - a concurrent request already redeemed the token, or a concurrent Create
//     replaced it;
//   - the token has expired;
//   - a database operation fails.
//
// The returned User has ID, Email and PasswordHash filled in from the query.
func (service *PasswordResetService) Consume(token string) (*User, error) {
	var pwReset PasswordReset
	var user User
	pwReset.TokenHash = hashToken(token)

	row := service.DB.QueryRow(`
		SELECT password_resets.id, password_resets.expires_at, users.id, users.email, users.password_hash
		FROM password_resets JOIN users ON users.id = password_resets.user_id
		WHERE password_resets.token_hash = $1;`, pwReset.TokenHash)
	err := row.Scan(&pwReset.ID, &pwReset.ExpiresAt, &user.ID, &user.Email, &user.PasswordHash)
	if err != nil {
		return nil, fmt.Errorf("consume: %w", err)
	}

	// Delete the row before the expiry check. An expired token is useless
	// anyway, so removing it is harmless. Only the caller whose DELETE actually
	// removes the row may proceed. This keeps the token single-use even when
	// two requests race on it; the SELECT above cannot guarantee that alone.
	// Matching on token_hash as well as id ensures we never delete a newer
	// token that a concurrent Create upserted into the same row.
	result, err := service.DB.Exec(`
		DELETE FROM password_resets
		WHERE id = $1 AND token_hash = $2;`, pwReset.ID, pwReset.TokenHash)
	if err != nil {
		return nil, fmt.Errorf("consume: %w", err)
	}
	deleted, err := result.RowsAffected()
	if err != nil {
		return nil, fmt.Errorf("consume: %w", err)
	}
	if deleted == 0 {
		return nil, fmt.Errorf("consume: invalid token")
	}

	if time.Now().After(pwReset.ExpiresAt) {
		return nil, fmt.Errorf("consume: invalid token")
	}
	return &user, nil
}
|