package models import ( "crypto/sha256" "database/sql" "encoding/base64" "fmt" "strings" "time" ) const ( resetTokenBytes int = 32 ) func resetToken() (string, error) { return RandString(resetTokenBytes) } func hashToken(token string) string { tokenHash := sha256.Sum256([]byte(token)) return base64.URLEncoding.EncodeToString(tokenHash[:]) } type PasswordReset struct { ID int UserID int // Token is only set when a PasswordReset is being created. Token string TokenHash string ExpiresAt time.Time } type PasswordResetService struct { DB *sql.DB BytesPerToken int Duration time.Duration } func (service *PasswordResetService) delete(id int) error { _, err := service.DB.Exec(` DELETE FROM password_resets WHERE id = $1;`, id) if err != nil { return fmt.Errorf("delete: %w", err) } return nil } func (service *PasswordResetService) Create(email string) (*PasswordReset, error) { // Verify we have a valid email address for user email = strings.ToLower(email) var UserID int row := service.DB.QueryRow(` SELECT id FROM users WHERE email = $1; `, email) err := row.Scan(&UserID) if err != nil { return nil, fmt.Errorf("Create: %w", err) } token, err := resetToken() if err != nil { return nil, fmt.Errorf("Create: %w", err) } duration := service.Duration if duration == 0 { duration = time.Hour } pwReset := PasswordReset{ UserID: UserID, Token: token, TokenHash: hashToken(token), ExpiresAt: time.Now().Add(duration), } row = service.DB.QueryRow(` INSERT INTO password_resets (user_id, token_hash, expires_at) VALUES ($1, $2, $3) ON CONFLICT (user_id) DO UPDATE SET token_hash = $2, expires_at = $3 RETURNING id; `, pwReset.UserID, pwReset.TokenHash, pwReset.ExpiresAt) err = row.Scan(&pwReset.ID) if err != nil { return nil, fmt.Errorf("Create: %w", err) } return &pwReset, nil } // We are going to consume a token and return the user associated with it, or return an error if the token wasn't valid for any reason. func (service *PasswordResetService) Consume(token string) (*User, error) { var pwReset PasswordReset var user User pwReset.TokenHash = hashToken(token) row := service.DB.QueryRow(` SELECT password_resets.id, password_resets.expires_at, users.id, users.email, users.password_hash FROM password_resets JOIN users ON users.id = password_resets.user_id WHERE password_resets.token_hash = $1; `, pwReset.TokenHash) err := row.Scan(&pwReset.ID, &pwReset.ExpiresAt, &user.ID, &user.Email, &user.PasswordHash) if err != nil { return nil, fmt.Errorf("Consume: %w", err) } if time.Now().After(pwReset.ExpiresAt) { return nil, fmt.Errorf("Invalid token") } err = service.delete(pwReset.ID) if err != nil { return nil, fmt.Errorf("consume: %w", err) } return &user, nil }