Add password reset service

This commit is contained in:
Lucas Schumacher 2024-09-12 07:33:10 -04:00
parent 56ce9fa2f8
commit 32f10ce40c
11 changed files with 539 additions and 38 deletions

View File

@ -3,6 +3,8 @@ package controllers
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"time"
userctx "git.kealoha.me/lks/lenslocked/context" userctx "git.kealoha.me/lks/lenslocked/context"
"git.kealoha.me/lks/lenslocked/models" "git.kealoha.me/lks/lenslocked/models"
@ -14,9 +16,14 @@ type Users struct {
Templates struct { Templates struct {
Signup Template Signup Template
Signin Template Signin Template
ForgotPass Template
ResetUrlSent Template
ResetPass Template
} }
UserService *models.UserService UserService *models.UserService
SessionService *models.SessionService SessionService *models.SessionService
PassResetService *models.PasswordResetService
EmailService *models.EmailService
} }
func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) { func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) {
@ -114,6 +121,87 @@ func (u Users) GetSignout(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/signin", http.StatusFound) http.Redirect(w, r, "/signin", http.StatusFound)
} }
func (u Users) GetForgotPassword(w http.ResponseWriter, r *http.Request) {
var data struct {
Email string
}
data.Email = r.FormValue("email")
u.Templates.ForgotPass.Execute(w, r, data)
}
func (u Users) PostForgotPassword(w http.ResponseWriter, r *http.Request) {
var data struct {
Email string
}
data.Email = r.FormValue("email")
pwReset, err := u.PassResetService.Create(data.Email)
if err != nil {
fmt.Println(err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
vals := url.Values{
"token": {pwReset.Token},
}
// TODO: Make the URL here configurable and use https
resetURL := "http://" + r.Host + "/reset-pw?" + vals.Encode()
fmt.Println(resetURL)
err = u.EmailService.SendPasswordReset(data.Email, resetURL)
if err != nil {
fmt.Println(err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
u.Templates.ResetUrlSent.Execute(w, r, data)
}
func (u Users) GetResetPass(w http.ResponseWriter, r *http.Request) {
var data struct {
Token string
}
data.Token = r.FormValue("token")
u.Templates.ResetPass.Execute(w, r, data)
}
func (u Users) PostResetPass(w http.ResponseWriter, r *http.Request) {
var data struct {
Token, Password string
}
data.Token = r.FormValue("token")
data.Password = r.FormValue("password")
user, err := u.PassResetService.Consume(data.Token)
if err != nil {
fmt.Println(err)
// TODO: Distinguish between server errors and invalid token errors.
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
err = u.UserService.UpdatePassword(user.ID, data.Password)
if err != nil {
fmt.Println(err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Sign the user in now that they have reset their password.
// Any errors from this point onward should redirect to the sign in page.
session, err := u.SessionService.Create(user.ID)
if err != nil {
fmt.Println(err)
http.Redirect(w, r, "/signin", http.StatusFound)
return
}
//setCookie(w, CookieSession, session.Token)
cookie := http.Cookie{
Name: "session",
Value: session.Token,
Path: "/",
HttpOnly: true,
}
http.SetCookie(w, &cookie)
http.Redirect(w, r, "/users/me", http.StatusFound)
}
func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) { func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) {
user := userctx.User(r.Context()) user := userctx.User(r.Context())
if user == nil { if user == nil {
@ -123,18 +211,32 @@ func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Current user: %s\n", user.Email) fmt.Fprintf(w, "Current user: %s\n", user.Email)
} }
func WithTemplates(user_service *models.UserService, session_service *models.SessionService, signup Template, signin Template) Users { func WithTemplates(user_service *models.UserService, session_service *models.SessionService, email_service *models.EmailService, signup, signin, forgotPass, resetUrlSent, resetPass Template) Users {
u := Users{} u := Users{}
u.Templates.Signup = signup u.Templates.Signup = signup
u.Templates.Signin = signin u.Templates.Signin = signin
u.Templates.ForgotPass = forgotPass
u.Templates.ResetUrlSent = resetUrlSent
u.Templates.ResetPass = resetPass
u.UserService = user_service u.UserService = user_service
u.SessionService = session_service u.SessionService = session_service
u.EmailService = email_service
u.PassResetService = &models.PasswordResetService{
DB: u.UserService.DB,
Duration: time.Hour / 2,
}
return u return u
} }
func Default(user_service *models.UserService, session_service *models.SessionService) Users { func Default(user_service *models.UserService, session_service *models.SessionService, email_service *models.EmailService) Users {
signup_tpl := views.Must(views.FromFS(templates.FS, "signup.gohtml", "tailwind.gohtml")) signup_tpl := views.Must(views.FromFS(templates.FS, "signup.gohtml", "tailwind.gohtml"))
signin_tpl := views.Must(views.FromFS(templates.FS, "signin.gohtml", "tailwind.gohtml")) signin_tpl := views.Must(views.FromFS(templates.FS, "signin.gohtml", "tailwind.gohtml"))
pwReset_tpl := views.Must(views.FromFS(templates.FS, "pwReset.gohtml", "tailwind.gohtml"))
pwResetSent_tpl := views.Must(views.FromFS(templates.FS, "pwResetSent.gohtml", "tailwind.gohtml"))
resetPass_tpl := views.Must(views.FromFS(templates.FS, "pwChange.gohtml", "tailwind.gohtml"))
err := signup_tpl.TestTemplate(nil) err := signup_tpl.TestTemplate(nil)
if err != nil { if err != nil {
@ -144,6 +246,14 @@ func Default(user_service *models.UserService, session_service *models.SessionSe
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = pwReset_tpl.TestTemplate(nil)
return WithTemplates(user_service, session_service, signup_tpl, signin_tpl) if err != nil {
panic(err)
}
err = pwResetSent_tpl.TestTemplate(nil)
if err != nil {
panic(err)
}
return WithTemplates(user_service, session_service, email_service, signup_tpl, signin_tpl, pwReset_tpl, pwResetSent_tpl, resetPass_tpl)
} }

97
main.go
View File

@ -23,14 +23,68 @@ import (
const DEBUG bool = true const DEBUG bool = true
type config struct {
Postgres string
Email struct {
Host string
Port int
Username, Pass, Sender string
}
Csrf struct {
Key []byte
Secure bool
}
Server struct {
Address string
}
}
func loadConfig() (config, error) {
var cfg config
cfg.Csrf.Secure = !DEBUG
err := godotenv.Load()
if err != nil {
fmt.Println("Warning: Could not load a .env file")
}
cfg.Csrf.Key = []byte(os.Getenv("LENSLOCKED_CSRF_KEY"))
if len(cfg.Csrf.Key) < 32 {
return cfg, fmt.Errorf("Error: no or bad csrf protection key\nPlease set the LENSLOCKED_CSRF_KEY env var to a key at least 32 characters long.")
}
cfg.Postgres = os.Getenv("LENSLOCKED_DB_STRING")
cfg.Email.Host = os.Getenv("LENSLOCKED_EMAIL_HOST")
cfg.Email.Username = os.Getenv("LENSLOCKED_EMAIL_USERNAME")
cfg.Email.Pass = os.Getenv("LENSLOCKED_EMAIL_PASSWORD")
cfg.Email.Sender = os.Getenv("LENSLOCKED_EMAIL_FROM")
cfg.Email.Port, err = strconv.Atoi(os.Getenv("LENSLOCKED_EMAIL_PORT"))
if err != nil {
fmt.Println("Warning: Invalid STMP port set in LENSLOCKED_EMAIL_PORT. Using port 587")
cfg.Email.Port = 587
}
cfg.Server.Address = os.Getenv("LENSLOCKED_ADDRESS")
if cfg.Server.Address == "" {
if DEBUG {
cfg.Server.Address = ":3000"
} else {
return cfg, fmt.Errorf("No server address set\nPlease set the LENSLOCKED_ADDRESS env var to the servers address")
}
}
return cfg, nil
}
func notFoundHandler(w http.ResponseWriter, r *http.Request) { func notFoundHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf8") w.Header().Set("Content-Type", "text/html; charset=utf8")
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, "404 page not found") fmt.Fprint(w, "404 page not found")
} }
func ConnectDB() *sql.DB { func ConnectDB(dbstr string) *sql.DB {
db, err := sql.Open("pgx", os.Getenv("LENSLOCKED_DB_STRING")) db, err := sql.Open("pgx", dbstr)
if err != nil { if err != nil {
panic(fmt.Sprint("Error connecting to database: %w", err)) panic(fmt.Sprint("Error connecting to database: %w", err))
} }
@ -55,29 +109,11 @@ func MigrateDB(db *sql.DB, subfs fs.FS) error {
} }
func main() { func main() {
err := godotenv.Load() cfg, err := loadConfig()
if err != nil { if err != nil {
fmt.Println("Warning: Could not load .env file") panic(err)
} }
db := ConnectDB(cfg.Postgres)
var (
email_host = os.Getenv("LENSLOCKED_EMAIL_HOST")
email_port_str = os.Getenv("LENSLOCKED_EMAIL_PORT")
email_username = os.Getenv("LENSLOCKED_EMAIL_USERNAME")
email_pass = os.Getenv("LENSLOCKED_EMAIL_PASSWORD")
email_sender = os.Getenv("LENSLOCKED_EMAIL_FROM")
csrfKey = []byte(os.Getenv("LENSLOCKED_CSRF_KEY"))
)
if len(csrfKey) < 32 {
panic("Error: no or bad csrf protection key\nPlease set the LENSLOCKED_CSRF_KEY env var to a key at least 32 characters long.")
}
email_port, err := strconv.Atoi(email_port_str)
if err != nil {
fmt.Println("Warning: Invalid STMP port set in LENSLOCKED_EMAIL_PORT. Using port 587")
email_port = 587
}
db := ConnectDB()
defer db.Close() defer db.Close()
err = MigrateDB(db, migrations.FS) err = MigrateDB(db, migrations.FS)
if err != nil { if err != nil {
@ -86,15 +122,15 @@ func main() {
userService := models.UserService{DB: db} userService := models.UserService{DB: db}
sessionService := models.SessionService{DB: db} sessionService := models.SessionService{DB: db}
_ = models.NewEmailService(email_host, email_port, email_username, email_pass, email_sender) emailService := models.NewEmailService(cfg.Email.Host, cfg.Email.Port, cfg.Email.Username, cfg.Email.Pass, cfg.Email.Sender)
var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService, &sessionService) var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService, &sessionService, emailService)
umw := userctx.UserMiddleware{SS: &sessionService} umw := userctx.UserMiddleware{SS: &sessionService}
r := chi.NewRouter() r := chi.NewRouter()
r.Use(middleware.Logger) r.Use(middleware.Logger)
r.Use(csrf.Protect(csrfKey, csrf.Secure(!DEBUG))) r.Use(csrf.Protect(cfg.Csrf.Key, csrf.Secure(cfg.Csrf.Secure)))
r.Use(umw.SetUser) r.Use(umw.SetUser)
r.Get("/", ctrlrs.StaticController("home.gohtml", "tailwind.gohtml")) r.Get("/", ctrlrs.StaticController("home.gohtml", "tailwind.gohtml"))
@ -106,11 +142,14 @@ func main() {
r.Get("/signin", usersCtrlr.GetSignin) r.Get("/signin", usersCtrlr.GetSignin)
r.Post("/signin", usersCtrlr.PostSignin) r.Post("/signin", usersCtrlr.PostSignin)
r.Post("/signout", usersCtrlr.GetSignout) r.Post("/signout", usersCtrlr.GetSignout)
r.Get("/forgot-pw", usersCtrlr.GetForgotPassword)
r.Post("/forgot-pw", usersCtrlr.PostForgotPassword)
r.Get("/reset-pw", usersCtrlr.GetResetPass)
r.Post("/reset-pw", usersCtrlr.PostResetPass)
//r.Get("/user", usersCtrlr.CurrentUser)
r.Get("/user", umw.RequireUserfn(usersCtrlr.CurrentUser)) r.Get("/user", umw.RequireUserfn(usersCtrlr.CurrentUser))
r.NotFound(notFoundHandler) r.NotFound(notFoundHandler)
fmt.Println("Starting the server on :3000...") fmt.Printf("Starting the server on %s...\n", cfg.Server.Address)
http.ListenAndServe(":3000", r) http.ListenAndServe(cfg.Server.Address, r)
} }

View File

@ -0,0 +1,14 @@
-- +goose Up
-- +goose StatementBegin
CREATE TABLE password_resets (
id SERIAL PRIMARY KEY,
user_id INT UNIQUE REFERENCES users (id) ON DELETE CASCADE,
token_hash TEXT UNIQUE NOT NULL,
expires_at TIMESTAMPTZ NOT NULL
);
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
DROP TABLE password_resets;
-- +goose StatementEnd

View File

@ -49,3 +49,17 @@ func (es *EmailService) Send(email Email) error {
} }
return nil return nil
} }
func (es *EmailService) SendPasswordReset(to, resetURL string) error {
email := Email{
Subject: "Reset your password",
To: to,
Text: "To reset your password, please visit the following link: " + resetURL,
Html: `<p>To reset your password, please visit the following link: <a href="` + resetURL + `">` + resetURL + `</a></p>`,
}
err := es.Send(email)
if err != nil {
return fmt.Errorf("forgot password email: %w", err)
}
return nil
}

118
models/password_reset.go Normal file
View File

@ -0,0 +1,118 @@
package models
import (
"crypto/sha256"
"database/sql"
"encoding/base64"
"fmt"
"strings"
"time"
)
const (
resetTokenBytes int = 32
)
func resetToken() (string, error) {
return RandString(resetTokenBytes)
}
func hashToken(token string) string {
tokenHash := sha256.Sum256([]byte(token))
return base64.URLEncoding.EncodeToString(tokenHash[:])
}
type PasswordReset struct {
ID int
UserID int
// Token is only set when a PasswordReset is being created.
Token string
TokenHash string
ExpiresAt time.Time
}
type PasswordResetService struct {
DB *sql.DB
BytesPerToken int
Duration time.Duration
}
func (service *PasswordResetService) delete(id int) error {
_, err := service.DB.Exec(`
DELETE FROM password_resets
WHERE id = $1;`, id)
if err != nil {
return fmt.Errorf("delete: %w", err)
}
return nil
}
func (service *PasswordResetService) Create(email string) (*PasswordReset, error) {
// Verify we have a valid email address for user
email = strings.ToLower(email)
var UserID int
row := service.DB.QueryRow(`
SELECT id FROM users WHERE email = $1;
`, email)
err := row.Scan(&UserID)
if err != nil {
return nil, fmt.Errorf("Create: %w", err)
}
token, err := resetToken()
if err != nil {
return nil, fmt.Errorf("Create: %w", err)
}
duration := service.Duration
if duration == 0 {
duration = time.Hour
}
pwReset := PasswordReset{
UserID: UserID,
Token: token,
TokenHash: hashToken(token),
ExpiresAt: time.Now().Add(duration),
}
row = service.DB.QueryRow(`
INSERT INTO password_resets (user_id, token_hash, expires_at)
VALUES ($1, $2, $3)
ON CONFLICT (user_id) DO UPDATE
SET token_hash = $2, expires_at = $3
RETURNING id;
`, pwReset.UserID, pwReset.TokenHash, pwReset.ExpiresAt)
err = row.Scan(&pwReset.ID)
if err != nil {
return nil, fmt.Errorf("Create: %w", err)
}
return &pwReset, nil
}
// We are going to consume a token and return the user associated with it, or return an error if the token wasn't valid for any reason.
func (service *PasswordResetService) Consume(token string) (*User, error) {
var pwReset PasswordReset
var user User
pwReset.TokenHash = hashToken(token)
row := service.DB.QueryRow(`
SELECT password_resets.id, password_resets.expires_at, users.id, users.email, users.password_hash
FROM password_resets JOIN users ON users.id = password_resets.user_id
WHERE password_resets.token_hash = $1;
`, pwReset.TokenHash)
err := row.Scan(&pwReset.ID, &pwReset.ExpiresAt, &user.ID, &user.Email, &user.PasswordHash)
if err != nil {
return nil, fmt.Errorf("Consume: %w", err)
}
if time.Now().After(pwReset.ExpiresAt) {
return nil, fmt.Errorf("Invalid token")
}
err = service.delete(pwReset.ID)
if err != nil {
return nil, fmt.Errorf("consume: %w", err)
}
return &user, nil
}

View File

@ -42,6 +42,22 @@ func (us *UserService) Create(email, password string) (*User, error) {
return &user, nil return &user, nil
} }
func (us *UserService) UpdatePassword(userID int, password string) error {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("update password: %w", err)
}
passwordHash := string(hashedBytes)
_, err = us.DB.Exec(`
UPDATE users
SET password_hash = $2
WHERE id = $1;`, userID, passwordHash)
if err != nil {
return fmt.Errorf("update password: %w", err)
}
return nil
}
func (us UserService) Authenticate(email, password string) (*User, error) { func (us UserService) Authenticate(email, password string) (*User, error) {
user := User{ user := User{
Email: strings.ToLower(email), Email: strings.ToLower(email),

98
templates/pwChange.gohtml Normal file
View File

@ -0,0 +1,98 @@
<!doctype html>
<html>
{{template "head" .}}
<body class="min-h-screen bg-gray-100">
{{template "header".}}
<main class="px-6">
<div class="py-12 flex justify-center">
<div class="px-8 py-8 bg-white rounded shadow">
<h1 class="pt-4 pb-8 text-center text-3xl font-bold text-gray-900">
Reset your password
</h1>
<form action="/reset-pw" method="post">
<div class="hidden">
{{csrfField}}
</div>
<div class="py-2">
<label for="password" class="text-sm font-semibold text-gray-800"
>New password</label
>
<input
name="password"
id="password"
type="password"
placeholder="Password"
required
class="
w-full
px-3
py-2
border border-gray-300
placeholder-gray-500
text-gray-800
rounded
"
autofocus
/>
</div>
{{if .Token}}
<div class="hidden">
<input type="hidden" id="token" name="token" value="{{.Token}}" />
</div>
{{else}}
<div class="py-2">
<label for="token" class="text-sm font-semibold text-gray-800"
>Password Reset Token</label
>
<input
name="token"
id="token"
type="text"
placeholder="Check your email"
required
class="
w-full
px-3
py-2
border border-gray-300
placeholder-gray-500
text-gray-800
rounded
"
/>
</div>
{{end}}
<div class="py-4">
<button
type="submit"
class="
w-full
py-4
px-2
bg-indigo-600
hover:bg-indigo-700
text-white
rounded
font-bold
text-lg
"
>
Update password
</button>
</div>
<div class="py-2 w-full flex justify-between">
<p class="text-xs text-gray-500">
<a href="/signup" class="underline">Sign up</a>
</p>
<p class="text-xs text-gray-500">
<a href="/signin" class="underline">Sign in</a>
</p>
</div>
</form>
</div>
</div>
</main>
{{template "footer" .}}
</body>
</html>

74
templates/pwReset.gohtml Normal file
View File

@ -0,0 +1,74 @@
<!doctype html>
<html>
{{template "head" .}}
<body class="min-h-screen bg-gray-100">
{{template "header".}}
<main class="px-6">
<div class="py-12 flex justify-center">
<div class="px-8 py-8 bg-white rounded shadow">
<h1 class="pt-4 pb-8 text-center text-3xl font-bold text-gray-900">
Forgot your password?
</h1>
<p class="text-sm text-gray-600 pb-4">No problem. Enter your email address and we'll send you a link to reset your password.</p>
<form action="/forgot-pw" method="post">
<div class="hidden">
{{csrfField}}
</div>
<div class="py-2">
<label for="email" class="text-sm font-semibold text-gray-800"
>Email Address</label
>
<input
name="email"
id="email"
type="email"
placeholder="Email address"
required
autocomplete="email"
class="
w-full
px-3
py-2
border border-gray-300
placeholder-gray-500
text-gray-800
rounded
"
value="{{.Email}}"
autofocus
/>
</div>
<div class="py-4">
<button
type="submit"
class="
w-full
py-4
px-2
bg-indigo-600
hover:bg-indigo-700
text-white
rounded
font-bold
text-lg
"
>
Reset password
</button>
</div>
<div class="py-2 w-full flex justify-between">
<p class="text-xs text-gray-500">
Need an account?
<a href="/signup" class="underline">Sign up</a>
</p>
<p class="text-xs text-gray-500">
<a href="/signin" class="underline">Remember your password?</a>
</p>
</div>
</form>
</div>
</div>
</main>
{{template "footer" .}}
</body>
</html>

View File

@ -0,0 +1,18 @@
<!doctype html>
<html>
{{template "head" .}}
<body class="min-h-screen bg-gray-100">
{{template "header".}}
<main class="px-6">
<div class="py-12 flex justify-center">
<div class="px-8 py-8 bg-white rounded shadow">
<h1 class="pt-4 pb-8 text-center text-3xl font-bold text-gray-900">
Check your email
</h1>
<p class="text-sm text-gray-600 pb-4">An email has been sent to the email address {{.Email}} with instructions to reset your password.</p>
</div>
</div>
</main>
{{template "footer" .}}
</body>
</html>

View File

@ -54,7 +54,7 @@
<a href="/signup" class="underline">Sign up</a> <a href="/signup" class="underline">Sign up</a>
</p> </p>
<p class="text-xs text-gray-500"> <p class="text-xs text-gray-500">
<a href="/reset-pw" class="underline">Forgot your password?</a> <a href="/forgot-pw" class="underline">Forgot your password?</a>
</p> </p>
</div> </div>
</form> </form>

View File

@ -41,7 +41,7 @@
<a href="/signin" class="underline">Sign in</a> <a href="/signin" class="underline">Sign in</a>
</p> </p>
<p class="text-xs text-gray-500"> <p class="text-xs text-gray-500">
<a href="/reset-pw" class="underline">Forgot your password?</a> <a href="/forgot-pw" class="underline">Forgot your password?</a>
</p> </p>
</div> </div>
</form> </form>