From 32f10ce40c41d67947236db7a7a6f92ef78a03c9 Mon Sep 17 00:00:00 2001
From: Lucas Schumacher
Date: Thu, 12 Sep 2024 07:33:10 -0400
Subject: [PATCH] Add password reset service
---
controllers/users.go | 124 ++++++++++++++++++++++++++--
main.go | 97 +++++++++++++++-------
migrations/00003_password_reset.sql | 14 ++++
models/email.go | 14 ++++
models/password_reset.go | 118 ++++++++++++++++++++++++++
models/user.go | 16 ++++
templates/pwChange.gohtml | 98 ++++++++++++++++++++++
templates/pwReset.gohtml | 74 +++++++++++++++++
templates/pwResetSent.gohtml | 18 ++++
templates/signin.gohtml | 2 +-
templates/signup.gohtml | 2 +-
11 files changed, 539 insertions(+), 38 deletions(-)
create mode 100644 migrations/00003_password_reset.sql
create mode 100644 models/password_reset.go
create mode 100644 templates/pwChange.gohtml
create mode 100644 templates/pwReset.gohtml
create mode 100644 templates/pwResetSent.gohtml
diff --git a/controllers/users.go b/controllers/users.go
index 7e9314b..0d4122e 100644
--- a/controllers/users.go
+++ b/controllers/users.go
@@ -3,6 +3,8 @@ package controllers
import (
"fmt"
"net/http"
+ "net/url"
+ "time"
userctx "git.kealoha.me/lks/lenslocked/context"
"git.kealoha.me/lks/lenslocked/models"
@@ -12,11 +14,16 @@ import (
type Users struct {
Templates struct {
- Signup Template
- Signin Template
+ Signup Template
+ Signin Template
+ ForgotPass Template
+ ResetUrlSent Template
+ ResetPass Template
}
- UserService *models.UserService
- SessionService *models.SessionService
+ UserService *models.UserService
+ SessionService *models.SessionService
+ PassResetService *models.PasswordResetService
+ EmailService *models.EmailService
}
func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) {
@@ -114,6 +121,87 @@ func (u Users) GetSignout(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/signin", http.StatusFound)
}
+func (u Users) GetForgotPassword(w http.ResponseWriter, r *http.Request) {
+ var data struct {
+ Email string
+ }
+ data.Email = r.FormValue("email")
+ u.Templates.ForgotPass.Execute(w, r, data)
+}
+
+func (u Users) PostForgotPassword(w http.ResponseWriter, r *http.Request) {
+ var data struct {
+ Email string
+ }
+ data.Email = r.FormValue("email")
+ pwReset, err := u.PassResetService.Create(data.Email)
+ if err != nil {
+ fmt.Println(err)
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+ vals := url.Values{
+ "token": {pwReset.Token},
+ }
+ // TODO: Make the URL here configurable and use https
+ resetURL := "http://" + r.Host + "/reset-pw?" + vals.Encode()
+ fmt.Println(resetURL)
+ err = u.EmailService.SendPasswordReset(data.Email, resetURL)
+ if err != nil {
+ fmt.Println(err)
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+ u.Templates.ResetUrlSent.Execute(w, r, data)
+}
+
+func (u Users) GetResetPass(w http.ResponseWriter, r *http.Request) {
+ var data struct {
+ Token string
+ }
+ data.Token = r.FormValue("token")
+ u.Templates.ResetPass.Execute(w, r, data)
+}
+func (u Users) PostResetPass(w http.ResponseWriter, r *http.Request) {
+ var data struct {
+ Token, Password string
+ }
+ data.Token = r.FormValue("token")
+ data.Password = r.FormValue("password")
+
+ user, err := u.PassResetService.Consume(data.Token)
+ if err != nil {
+ fmt.Println(err)
+ // TODO: Distinguish between server errors and invalid token errors.
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+ err = u.UserService.UpdatePassword(user.ID, data.Password)
+ if err != nil {
+ fmt.Println(err)
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+
+ // Sign the user in now that they have reset their password.
+ // Any errors from this point onward should redirect to the sign in page.
+ session, err := u.SessionService.Create(user.ID)
+ if err != nil {
+ fmt.Println(err)
+ http.Redirect(w, r, "/signin", http.StatusFound)
+ return
+ }
+ //setCookie(w, CookieSession, session.Token)
+ cookie := http.Cookie{
+ Name: "session",
+ Value: session.Token,
+ Path: "/",
+ HttpOnly: true,
+ }
+ http.SetCookie(w, &cookie)
+ http.Redirect(w, r, "/users/me", http.StatusFound)
+}
+
func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) {
user := userctx.User(r.Context())
if user == nil {
@@ -123,18 +211,32 @@ func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Current user: %s\n", user.Email)
}
-func WithTemplates(user_service *models.UserService, session_service *models.SessionService, signup Template, signin Template) Users {
+func WithTemplates(user_service *models.UserService, session_service *models.SessionService, email_service *models.EmailService, signup, signin, forgotPass, resetUrlSent, resetPass Template) Users {
u := Users{}
+
u.Templates.Signup = signup
u.Templates.Signin = signin
+ u.Templates.ForgotPass = forgotPass
+ u.Templates.ResetUrlSent = resetUrlSent
+ u.Templates.ResetPass = resetPass
+
u.UserService = user_service
u.SessionService = session_service
+ u.EmailService = email_service
+ u.PassResetService = &models.PasswordResetService{
+ DB: u.UserService.DB,
+ Duration: time.Hour / 2,
+ }
+
return u
}
-func Default(user_service *models.UserService, session_service *models.SessionService) Users {
+func Default(user_service *models.UserService, session_service *models.SessionService, email_service *models.EmailService) Users {
signup_tpl := views.Must(views.FromFS(templates.FS, "signup.gohtml", "tailwind.gohtml"))
signin_tpl := views.Must(views.FromFS(templates.FS, "signin.gohtml", "tailwind.gohtml"))
+ pwReset_tpl := views.Must(views.FromFS(templates.FS, "pwReset.gohtml", "tailwind.gohtml"))
+ pwResetSent_tpl := views.Must(views.FromFS(templates.FS, "pwResetSent.gohtml", "tailwind.gohtml"))
+ resetPass_tpl := views.Must(views.FromFS(templates.FS, "pwChange.gohtml", "tailwind.gohtml"))
err := signup_tpl.TestTemplate(nil)
if err != nil {
@@ -144,6 +246,14 @@ func Default(user_service *models.UserService, session_service *models.SessionSe
if err != nil {
panic(err)
}
+ err = pwReset_tpl.TestTemplate(nil)
+ if err != nil {
+ panic(err)
+ }
+ err = pwResetSent_tpl.TestTemplate(nil)
+ if err != nil {
+ panic(err)
+ }
- return WithTemplates(user_service, session_service, signup_tpl, signin_tpl)
+ return WithTemplates(user_service, session_service, email_service, signup_tpl, signin_tpl, pwReset_tpl, pwResetSent_tpl, resetPass_tpl)
}
diff --git a/main.go b/main.go
index c104834..96f9320 100644
--- a/main.go
+++ b/main.go
@@ -23,14 +23,68 @@ import (
const DEBUG bool = true
+type config struct {
+ Postgres string
+ Email struct {
+ Host string
+ Port int
+ Username, Pass, Sender string
+ }
+ Csrf struct {
+ Key []byte
+ Secure bool
+ }
+ Server struct {
+ Address string
+ }
+}
+
+func loadConfig() (config, error) {
+ var cfg config
+ cfg.Csrf.Secure = !DEBUG
+
+ err := godotenv.Load()
+ if err != nil {
+ fmt.Println("Warning: Could not load a .env file")
+ }
+
+ cfg.Csrf.Key = []byte(os.Getenv("LENSLOCKED_CSRF_KEY"))
+ if len(cfg.Csrf.Key) < 32 {
+ return cfg, fmt.Errorf("Error: no or bad csrf protection key\nPlease set the LENSLOCKED_CSRF_KEY env var to a key at least 32 characters long.")
+ }
+
+ cfg.Postgres = os.Getenv("LENSLOCKED_DB_STRING")
+
+ cfg.Email.Host = os.Getenv("LENSLOCKED_EMAIL_HOST")
+ cfg.Email.Username = os.Getenv("LENSLOCKED_EMAIL_USERNAME")
+ cfg.Email.Pass = os.Getenv("LENSLOCKED_EMAIL_PASSWORD")
+ cfg.Email.Sender = os.Getenv("LENSLOCKED_EMAIL_FROM")
+ cfg.Email.Port, err = strconv.Atoi(os.Getenv("LENSLOCKED_EMAIL_PORT"))
+ if err != nil {
+ fmt.Println("Warning: Invalid STMP port set in LENSLOCKED_EMAIL_PORT. Using port 587")
+ cfg.Email.Port = 587
+ }
+
+ cfg.Server.Address = os.Getenv("LENSLOCKED_ADDRESS")
+ if cfg.Server.Address == "" {
+ if DEBUG {
+ cfg.Server.Address = ":3000"
+ } else {
+ return cfg, fmt.Errorf("No server address set\nPlease set the LENSLOCKED_ADDRESS env var to the servers address")
+ }
+ }
+
+ return cfg, nil
+}
+
func notFoundHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf8")
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, "404 page not found")
}
-func ConnectDB() *sql.DB {
- db, err := sql.Open("pgx", os.Getenv("LENSLOCKED_DB_STRING"))
+func ConnectDB(dbstr string) *sql.DB {
+ db, err := sql.Open("pgx", dbstr)
if err != nil {
panic(fmt.Sprint("Error connecting to database: %w", err))
}
@@ -55,29 +109,11 @@ func MigrateDB(db *sql.DB, subfs fs.FS) error {
}
func main() {
- err := godotenv.Load()
+ cfg, err := loadConfig()
if err != nil {
- fmt.Println("Warning: Could not load .env file")
+ panic(err)
}
-
- var (
- email_host = os.Getenv("LENSLOCKED_EMAIL_HOST")
- email_port_str = os.Getenv("LENSLOCKED_EMAIL_PORT")
- email_username = os.Getenv("LENSLOCKED_EMAIL_USERNAME")
- email_pass = os.Getenv("LENSLOCKED_EMAIL_PASSWORD")
- email_sender = os.Getenv("LENSLOCKED_EMAIL_FROM")
- csrfKey = []byte(os.Getenv("LENSLOCKED_CSRF_KEY"))
- )
- if len(csrfKey) < 32 {
- panic("Error: no or bad csrf protection key\nPlease set the LENSLOCKED_CSRF_KEY env var to a key at least 32 characters long.")
- }
- email_port, err := strconv.Atoi(email_port_str)
- if err != nil {
- fmt.Println("Warning: Invalid STMP port set in LENSLOCKED_EMAIL_PORT. Using port 587")
- email_port = 587
- }
-
- db := ConnectDB()
+ db := ConnectDB(cfg.Postgres)
defer db.Close()
err = MigrateDB(db, migrations.FS)
if err != nil {
@@ -86,15 +122,15 @@ func main() {
userService := models.UserService{DB: db}
sessionService := models.SessionService{DB: db}
- _ = models.NewEmailService(email_host, email_port, email_username, email_pass, email_sender)
- var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService, &sessionService)
+ emailService := models.NewEmailService(cfg.Email.Host, cfg.Email.Port, cfg.Email.Username, cfg.Email.Pass, cfg.Email.Sender)
+ var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService, &sessionService, emailService)
umw := userctx.UserMiddleware{SS: &sessionService}
r := chi.NewRouter()
r.Use(middleware.Logger)
- r.Use(csrf.Protect(csrfKey, csrf.Secure(!DEBUG)))
+ r.Use(csrf.Protect(cfg.Csrf.Key, csrf.Secure(cfg.Csrf.Secure)))
r.Use(umw.SetUser)
r.Get("/", ctrlrs.StaticController("home.gohtml", "tailwind.gohtml"))
@@ -106,11 +142,14 @@ func main() {
r.Get("/signin", usersCtrlr.GetSignin)
r.Post("/signin", usersCtrlr.PostSignin)
r.Post("/signout", usersCtrlr.GetSignout)
+ r.Get("/forgot-pw", usersCtrlr.GetForgotPassword)
+ r.Post("/forgot-pw", usersCtrlr.PostForgotPassword)
+ r.Get("/reset-pw", usersCtrlr.GetResetPass)
+ r.Post("/reset-pw", usersCtrlr.PostResetPass)
- //r.Get("/user", usersCtrlr.CurrentUser)
r.Get("/user", umw.RequireUserfn(usersCtrlr.CurrentUser))
r.NotFound(notFoundHandler)
- fmt.Println("Starting the server on :3000...")
- http.ListenAndServe(":3000", r)
+ fmt.Printf("Starting the server on %s...\n", cfg.Server.Address)
+ http.ListenAndServe(cfg.Server.Address, r)
}
diff --git a/migrations/00003_password_reset.sql b/migrations/00003_password_reset.sql
new file mode 100644
index 0000000..c89e9e9
--- /dev/null
+++ b/migrations/00003_password_reset.sql
@@ -0,0 +1,14 @@
+-- +goose Up
+-- +goose StatementBegin
+CREATE TABLE password_resets (
+ id SERIAL PRIMARY KEY,
+ user_id INT UNIQUE REFERENCES users (id) ON DELETE CASCADE,
+ token_hash TEXT UNIQUE NOT NULL,
+ expires_at TIMESTAMPTZ NOT NULL
+);
+-- +goose StatementEnd
+
+-- +goose Down
+-- +goose StatementBegin
+DROP TABLE password_resets;
+-- +goose StatementEnd
diff --git a/models/email.go b/models/email.go
index b4080dc..0f7a12c 100644
--- a/models/email.go
+++ b/models/email.go
@@ -49,3 +49,17 @@ func (es *EmailService) Send(email Email) error {
}
return nil
}
+
+func (es *EmailService) SendPasswordReset(to, resetURL string) error {
+ email := Email{
+ Subject: "Reset your password",
+ To: to,
+ Text: "To reset your password, please visit the following link: " + resetURL,
+ Html: `To reset your password, please visit the following link: ` + resetURL + `
`,
+ }
+ err := es.Send(email)
+ if err != nil {
+ return fmt.Errorf("forgot password email: %w", err)
+ }
+ return nil
+}
diff --git a/models/password_reset.go b/models/password_reset.go
new file mode 100644
index 0000000..ba44f99
--- /dev/null
+++ b/models/password_reset.go
@@ -0,0 +1,118 @@
+package models
+
+import (
+ "crypto/sha256"
+ "database/sql"
+ "encoding/base64"
+ "fmt"
+ "strings"
+ "time"
+)
+
+const (
+ resetTokenBytes int = 32
+)
+
+func resetToken() (string, error) {
+ return RandString(resetTokenBytes)
+}
+func hashToken(token string) string {
+ tokenHash := sha256.Sum256([]byte(token))
+ return base64.URLEncoding.EncodeToString(tokenHash[:])
+}
+
+type PasswordReset struct {
+ ID int
+ UserID int
+ // Token is only set when a PasswordReset is being created.
+ Token string
+ TokenHash string
+ ExpiresAt time.Time
+}
+
+type PasswordResetService struct {
+ DB *sql.DB
+ BytesPerToken int
+ Duration time.Duration
+}
+
+func (service *PasswordResetService) delete(id int) error {
+ _, err := service.DB.Exec(`
+ DELETE FROM password_resets
+ WHERE id = $1;`, id)
+ if err != nil {
+ return fmt.Errorf("delete: %w", err)
+ }
+ return nil
+}
+
+func (service *PasswordResetService) Create(email string) (*PasswordReset, error) {
+ // Verify we have a valid email address for user
+ email = strings.ToLower(email)
+ var UserID int
+ row := service.DB.QueryRow(`
+ SELECT id FROM users WHERE email = $1;
+ `, email)
+ err := row.Scan(&UserID)
+ if err != nil {
+ return nil, fmt.Errorf("Create: %w", err)
+ }
+
+ token, err := resetToken()
+ if err != nil {
+ return nil, fmt.Errorf("Create: %w", err)
+ }
+
+ duration := service.Duration
+ if duration == 0 {
+ duration = time.Hour
+ }
+
+ pwReset := PasswordReset{
+ UserID: UserID,
+ Token: token,
+ TokenHash: hashToken(token),
+ ExpiresAt: time.Now().Add(duration),
+ }
+
+ row = service.DB.QueryRow(`
+ INSERT INTO password_resets (user_id, token_hash, expires_at)
+ VALUES ($1, $2, $3)
+ ON CONFLICT (user_id) DO UPDATE
+ SET token_hash = $2, expires_at = $3
+ RETURNING id;
+ `, pwReset.UserID, pwReset.TokenHash, pwReset.ExpiresAt)
+ err = row.Scan(&pwReset.ID)
+
+ if err != nil {
+ return nil, fmt.Errorf("Create: %w", err)
+ }
+ return &pwReset, nil
+}
+
+// We are going to consume a token and return the user associated with it, or return an error if the token wasn't valid for any reason.
+func (service *PasswordResetService) Consume(token string) (*User, error) {
+ var pwReset PasswordReset
+ var user User
+ pwReset.TokenHash = hashToken(token)
+
+ row := service.DB.QueryRow(`
+ SELECT password_resets.id, password_resets.expires_at, users.id, users.email, users.password_hash
+ FROM password_resets JOIN users ON users.id = password_resets.user_id
+ WHERE password_resets.token_hash = $1;
+ `, pwReset.TokenHash)
+ err := row.Scan(&pwReset.ID, &pwReset.ExpiresAt, &user.ID, &user.Email, &user.PasswordHash)
+ if err != nil {
+ return nil, fmt.Errorf("Consume: %w", err)
+ }
+
+ if time.Now().After(pwReset.ExpiresAt) {
+ return nil, fmt.Errorf("Invalid token")
+ }
+
+ err = service.delete(pwReset.ID)
+ if err != nil {
+ return nil, fmt.Errorf("consume: %w", err)
+ }
+ return &user, nil
+}
diff --git a/models/user.go b/models/user.go
index 286dd91..9baee2a 100644
--- a/models/user.go
+++ b/models/user.go
@@ -42,6 +42,22 @@ func (us *UserService) Create(email, password string) (*User, error) {
return &user, nil
}
+func (us *UserService) UpdatePassword(userID int, password string) error {
+ hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
+ if err != nil {
+ return fmt.Errorf("update password: %w", err)
+ }
+ passwordHash := string(hashedBytes)
+ _, err = us.DB.Exec(`
+ UPDATE users
+ SET password_hash = $2
+ WHERE id = $1;`, userID, passwordHash)
+ if err != nil {
+ return fmt.Errorf("update password: %w", err)
+ }
+ return nil
+}
+
func (us UserService) Authenticate(email, password string) (*User, error) {
user := User{
Email: strings.ToLower(email),
diff --git a/templates/pwChange.gohtml b/templates/pwChange.gohtml
new file mode 100644
index 0000000..d1c9bc4
--- /dev/null
+++ b/templates/pwChange.gohtml
@@ -0,0 +1,98 @@
+
+
+ {{template "head" .}}
+
+ {{template "header".}}
+
+
+
+
+ Reset your password
+
+
+
+
+
+ {{template "footer" .}}
+
+
diff --git a/templates/pwReset.gohtml b/templates/pwReset.gohtml
new file mode 100644
index 0000000..04ad39b
--- /dev/null
+++ b/templates/pwReset.gohtml
@@ -0,0 +1,74 @@
+
+
+ {{template "head" .}}
+
+ {{template "header".}}
+
+
+
+
+ Forgot your password?
+
+
No problem. Enter your email address and we'll send you a link to reset your password.
+
+
+
+
+ {{template "footer" .}}
+
+
diff --git a/templates/pwResetSent.gohtml b/templates/pwResetSent.gohtml
new file mode 100644
index 0000000..b31db26
--- /dev/null
+++ b/templates/pwResetSent.gohtml
@@ -0,0 +1,18 @@
+
+
+ {{template "head" .}}
+
+ {{template "header".}}
+
+
+
+
+ Check your email
+
+
An email has been sent to the email address {{.Email}} with instructions to reset your password.
+
+
+
+ {{template "footer" .}}
+
+
diff --git a/templates/signin.gohtml b/templates/signin.gohtml
index a999ed0..e635c59 100644
--- a/templates/signin.gohtml
+++ b/templates/signin.gohtml
@@ -54,7 +54,7 @@
Sign up
- Forgot your password?
+ Forgot your password?
diff --git a/templates/signup.gohtml b/templates/signup.gohtml
index 106f9aa..e5db07f 100644
--- a/templates/signup.gohtml
+++ b/templates/signup.gohtml
@@ -41,7 +41,7 @@
Sign in
- Forgot your password?
+ Forgot your password?