diff --git a/controllers/users.go b/controllers/users.go index 7e9314b..0d4122e 100644 --- a/controllers/users.go +++ b/controllers/users.go @@ -3,6 +3,8 @@ package controllers import ( "fmt" "net/http" + "net/url" + "time" userctx "git.kealoha.me/lks/lenslocked/context" "git.kealoha.me/lks/lenslocked/models" @@ -12,11 +14,16 @@ import ( type Users struct { Templates struct { - Signup Template - Signin Template + Signup Template + Signin Template + ForgotPass Template + ResetUrlSent Template + ResetPass Template } - UserService *models.UserService - SessionService *models.SessionService + UserService *models.UserService + SessionService *models.SessionService + PassResetService *models.PasswordResetService + EmailService *models.EmailService } func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) { @@ -114,6 +121,87 @@ func (u Users) GetSignout(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/signin", http.StatusFound) } +func (u Users) GetForgotPassword(w http.ResponseWriter, r *http.Request) { + var data struct { + Email string + } + data.Email = r.FormValue("email") + u.Templates.ForgotPass.Execute(w, r, data) +} + +func (u Users) PostForgotPassword(w http.ResponseWriter, r *http.Request) { + var data struct { + Email string + } + data.Email = r.FormValue("email") + pwReset, err := u.PassResetService.Create(data.Email) + if err != nil { + fmt.Println(err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + vals := url.Values{ + "token": {pwReset.Token}, + } + // TODO: Make the URL here configurable and use https + resetURL := "http://" + r.Host + "/reset-pw?" + vals.Encode() + fmt.Println(resetURL) + err = u.EmailService.SendPasswordReset(data.Email, resetURL) + if err != nil { + fmt.Println(err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + u.Templates.ResetUrlSent.Execute(w, r, data) +} + +func (u Users) GetResetPass(w http.ResponseWriter, r *http.Request) { + var data struct { + Token string + } + data.Token = r.FormValue("token") + u.Templates.ResetPass.Execute(w, r, data) +} +func (u Users) PostResetPass(w http.ResponseWriter, r *http.Request) { + var data struct { + Token, Password string + } + data.Token = r.FormValue("token") + data.Password = r.FormValue("password") + + user, err := u.PassResetService.Consume(data.Token) + if err != nil { + fmt.Println(err) + // TODO: Distinguish between server errors and invalid token errors. + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + err = u.UserService.UpdatePassword(user.ID, data.Password) + if err != nil { + fmt.Println(err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + // Sign the user in now that they have reset their password. + // Any errors from this point onward should redirect to the sign in page. + session, err := u.SessionService.Create(user.ID) + if err != nil { + fmt.Println(err) + http.Redirect(w, r, "/signin", http.StatusFound) + return + } + //setCookie(w, CookieSession, session.Token) + cookie := http.Cookie{ + Name: "session", + Value: session.Token, + Path: "/", + HttpOnly: true, + } + http.SetCookie(w, &cookie) + http.Redirect(w, r, "/users/me", http.StatusFound) +} + func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) { user := userctx.User(r.Context()) if user == nil { @@ -123,18 +211,32 @@ func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "Current user: %s\n", user.Email) } -func WithTemplates(user_service *models.UserService, session_service *models.SessionService, signup Template, signin Template) Users { +func WithTemplates(user_service *models.UserService, session_service *models.SessionService, email_service *models.EmailService, signup, signin, forgotPass, resetUrlSent, resetPass Template) Users { u := Users{} + u.Templates.Signup = signup u.Templates.Signin = signin + u.Templates.ForgotPass = forgotPass + u.Templates.ResetUrlSent = resetUrlSent + u.Templates.ResetPass = resetPass + u.UserService = user_service u.SessionService = session_service + u.EmailService = email_service + u.PassResetService = &models.PasswordResetService{ + DB: u.UserService.DB, + Duration: time.Hour / 2, + } + return u } -func Default(user_service *models.UserService, session_service *models.SessionService) Users { +func Default(user_service *models.UserService, session_service *models.SessionService, email_service *models.EmailService) Users { signup_tpl := views.Must(views.FromFS(templates.FS, "signup.gohtml", "tailwind.gohtml")) signin_tpl := views.Must(views.FromFS(templates.FS, "signin.gohtml", "tailwind.gohtml")) + pwReset_tpl := views.Must(views.FromFS(templates.FS, "pwReset.gohtml", "tailwind.gohtml")) + pwResetSent_tpl := views.Must(views.FromFS(templates.FS, "pwResetSent.gohtml", "tailwind.gohtml")) + resetPass_tpl := views.Must(views.FromFS(templates.FS, "pwChange.gohtml", "tailwind.gohtml")) err := signup_tpl.TestTemplate(nil) if err != nil { @@ -144,6 +246,14 @@ func Default(user_service *models.UserService, session_service *models.SessionSe if err != nil { panic(err) } + err = pwReset_tpl.TestTemplate(nil) + if err != nil { + panic(err) + } + err = pwResetSent_tpl.TestTemplate(nil) + if err != nil { + panic(err) + } - return WithTemplates(user_service, session_service, signup_tpl, signin_tpl) + return WithTemplates(user_service, session_service, email_service, signup_tpl, signin_tpl, pwReset_tpl, pwResetSent_tpl, resetPass_tpl) } diff --git a/main.go b/main.go index c104834..96f9320 100644 --- a/main.go +++ b/main.go @@ -23,14 +23,68 @@ import ( const DEBUG bool = true +type config struct { + Postgres string + Email struct { + Host string + Port int + Username, Pass, Sender string + } + Csrf struct { + Key []byte + Secure bool + } + Server struct { + Address string + } +} + +func loadConfig() (config, error) { + var cfg config + cfg.Csrf.Secure = !DEBUG + + err := godotenv.Load() + if err != nil { + fmt.Println("Warning: Could not load a .env file") + } + + cfg.Csrf.Key = []byte(os.Getenv("LENSLOCKED_CSRF_KEY")) + if len(cfg.Csrf.Key) < 32 { + return cfg, fmt.Errorf("Error: no or bad csrf protection key\nPlease set the LENSLOCKED_CSRF_KEY env var to a key at least 32 characters long.") + } + + cfg.Postgres = os.Getenv("LENSLOCKED_DB_STRING") + + cfg.Email.Host = os.Getenv("LENSLOCKED_EMAIL_HOST") + cfg.Email.Username = os.Getenv("LENSLOCKED_EMAIL_USERNAME") + cfg.Email.Pass = os.Getenv("LENSLOCKED_EMAIL_PASSWORD") + cfg.Email.Sender = os.Getenv("LENSLOCKED_EMAIL_FROM") + cfg.Email.Port, err = strconv.Atoi(os.Getenv("LENSLOCKED_EMAIL_PORT")) + if err != nil { + fmt.Println("Warning: Invalid STMP port set in LENSLOCKED_EMAIL_PORT. Using port 587") + cfg.Email.Port = 587 + } + + cfg.Server.Address = os.Getenv("LENSLOCKED_ADDRESS") + if cfg.Server.Address == "" { + if DEBUG { + cfg.Server.Address = ":3000" + } else { + return cfg, fmt.Errorf("No server address set\nPlease set the LENSLOCKED_ADDRESS env var to the servers address") + } + } + + return cfg, nil +} + func notFoundHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf8") w.WriteHeader(http.StatusNotFound) fmt.Fprint(w, "404 page not found") } -func ConnectDB() *sql.DB { - db, err := sql.Open("pgx", os.Getenv("LENSLOCKED_DB_STRING")) +func ConnectDB(dbstr string) *sql.DB { + db, err := sql.Open("pgx", dbstr) if err != nil { panic(fmt.Sprint("Error connecting to database: %w", err)) } @@ -55,29 +109,11 @@ func MigrateDB(db *sql.DB, subfs fs.FS) error { } func main() { - err := godotenv.Load() + cfg, err := loadConfig() if err != nil { - fmt.Println("Warning: Could not load .env file") + panic(err) } - - var ( - email_host = os.Getenv("LENSLOCKED_EMAIL_HOST") - email_port_str = os.Getenv("LENSLOCKED_EMAIL_PORT") - email_username = os.Getenv("LENSLOCKED_EMAIL_USERNAME") - email_pass = os.Getenv("LENSLOCKED_EMAIL_PASSWORD") - email_sender = os.Getenv("LENSLOCKED_EMAIL_FROM") - csrfKey = []byte(os.Getenv("LENSLOCKED_CSRF_KEY")) - ) - if len(csrfKey) < 32 { - panic("Error: no or bad csrf protection key\nPlease set the LENSLOCKED_CSRF_KEY env var to a key at least 32 characters long.") - } - email_port, err := strconv.Atoi(email_port_str) - if err != nil { - fmt.Println("Warning: Invalid STMP port set in LENSLOCKED_EMAIL_PORT. Using port 587") - email_port = 587 - } - - db := ConnectDB() + db := ConnectDB(cfg.Postgres) defer db.Close() err = MigrateDB(db, migrations.FS) if err != nil { @@ -86,15 +122,15 @@ func main() { userService := models.UserService{DB: db} sessionService := models.SessionService{DB: db} - _ = models.NewEmailService(email_host, email_port, email_username, email_pass, email_sender) - var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService, &sessionService) + emailService := models.NewEmailService(cfg.Email.Host, cfg.Email.Port, cfg.Email.Username, cfg.Email.Pass, cfg.Email.Sender) + var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService, &sessionService, emailService) umw := userctx.UserMiddleware{SS: &sessionService} r := chi.NewRouter() r.Use(middleware.Logger) - r.Use(csrf.Protect(csrfKey, csrf.Secure(!DEBUG))) + r.Use(csrf.Protect(cfg.Csrf.Key, csrf.Secure(cfg.Csrf.Secure))) r.Use(umw.SetUser) r.Get("/", ctrlrs.StaticController("home.gohtml", "tailwind.gohtml")) @@ -106,11 +142,14 @@ func main() { r.Get("/signin", usersCtrlr.GetSignin) r.Post("/signin", usersCtrlr.PostSignin) r.Post("/signout", usersCtrlr.GetSignout) + r.Get("/forgot-pw", usersCtrlr.GetForgotPassword) + r.Post("/forgot-pw", usersCtrlr.PostForgotPassword) + r.Get("/reset-pw", usersCtrlr.GetResetPass) + r.Post("/reset-pw", usersCtrlr.PostResetPass) - //r.Get("/user", usersCtrlr.CurrentUser) r.Get("/user", umw.RequireUserfn(usersCtrlr.CurrentUser)) r.NotFound(notFoundHandler) - fmt.Println("Starting the server on :3000...") - http.ListenAndServe(":3000", r) + fmt.Printf("Starting the server on %s...\n", cfg.Server.Address) + http.ListenAndServe(cfg.Server.Address, r) } diff --git a/migrations/00003_password_reset.sql b/migrations/00003_password_reset.sql new file mode 100644 index 0000000..c89e9e9 --- /dev/null +++ b/migrations/00003_password_reset.sql @@ -0,0 +1,14 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TABLE password_resets ( + id SERIAL PRIMARY KEY, + user_id INT UNIQUE REFERENCES users (id) ON DELETE CASCADE, + token_hash TEXT UNIQUE NOT NULL, + expires_at TIMESTAMPTZ NOT NULL +); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP TABLE password_resets; +-- +goose StatementEnd diff --git a/models/email.go b/models/email.go index b4080dc..0f7a12c 100644 --- a/models/email.go +++ b/models/email.go @@ -49,3 +49,17 @@ func (es *EmailService) Send(email Email) error { } return nil } + +func (es *EmailService) SendPasswordReset(to, resetURL string) error { + email := Email{ + Subject: "Reset your password", + To: to, + Text: "To reset your password, please visit the following link: " + resetURL, + Html: `

To reset your password, please visit the following link: ` + resetURL + `

`, + } + err := es.Send(email) + if err != nil { + return fmt.Errorf("forgot password email: %w", err) + } + return nil +} diff --git a/models/password_reset.go b/models/password_reset.go new file mode 100644 index 0000000..ba44f99 --- /dev/null +++ b/models/password_reset.go @@ -0,0 +1,118 @@ +package models + +import ( + "crypto/sha256" + "database/sql" + "encoding/base64" + "fmt" + "strings" + "time" +) + +const ( + resetTokenBytes int = 32 +) + +func resetToken() (string, error) { + return RandString(resetTokenBytes) +} +func hashToken(token string) string { + tokenHash := sha256.Sum256([]byte(token)) + return base64.URLEncoding.EncodeToString(tokenHash[:]) +} + +type PasswordReset struct { + ID int + UserID int + // Token is only set when a PasswordReset is being created. + Token string + TokenHash string + ExpiresAt time.Time +} + +type PasswordResetService struct { + DB *sql.DB + BytesPerToken int + Duration time.Duration +} + +func (service *PasswordResetService) delete(id int) error { + _, err := service.DB.Exec(` + DELETE FROM password_resets + WHERE id = $1;`, id) + if err != nil { + return fmt.Errorf("delete: %w", err) + } + return nil +} + +func (service *PasswordResetService) Create(email string) (*PasswordReset, error) { + // Verify we have a valid email address for user + email = strings.ToLower(email) + var UserID int + row := service.DB.QueryRow(` + SELECT id FROM users WHERE email = $1; + `, email) + err := row.Scan(&UserID) + if err != nil { + return nil, fmt.Errorf("Create: %w", err) + } + + token, err := resetToken() + if err != nil { + return nil, fmt.Errorf("Create: %w", err) + } + + duration := service.Duration + if duration == 0 { + duration = time.Hour + } + + pwReset := PasswordReset{ + UserID: UserID, + Token: token, + TokenHash: hashToken(token), + ExpiresAt: time.Now().Add(duration), + } + + row = service.DB.QueryRow(` + INSERT INTO password_resets (user_id, token_hash, expires_at) + VALUES ($1, $2, $3) + ON CONFLICT (user_id) DO UPDATE + SET token_hash = $2, expires_at = $3 + RETURNING id; + `, pwReset.UserID, pwReset.TokenHash, pwReset.ExpiresAt) + err = row.Scan(&pwReset.ID) + + if err != nil { + return nil, fmt.Errorf("Create: %w", err) + } + return &pwReset, nil +} + +// We are going to consume a token and return the user associated with it, or return an error if the token wasn't valid for any reason. +func (service *PasswordResetService) Consume(token string) (*User, error) { + var pwReset PasswordReset + var user User + pwReset.TokenHash = hashToken(token) + + row := service.DB.QueryRow(` + SELECT password_resets.id, password_resets.expires_at, users.id, users.email, users.password_hash + FROM password_resets JOIN users ON users.id = password_resets.user_id + WHERE password_resets.token_hash = $1; + `, pwReset.TokenHash) + err := row.Scan(&pwReset.ID, &pwReset.ExpiresAt, &user.ID, &user.Email, &user.PasswordHash) + if err != nil { + return nil, fmt.Errorf("Consume: %w", err) + } + + if time.Now().After(pwReset.ExpiresAt) { + return nil, fmt.Errorf("Invalid token") + } + + err = service.delete(pwReset.ID) + if err != nil { + return nil, fmt.Errorf("consume: %w", err) + } + return &user, nil +} diff --git a/models/user.go b/models/user.go index 286dd91..9baee2a 100644 --- a/models/user.go +++ b/models/user.go @@ -42,6 +42,22 @@ func (us *UserService) Create(email, password string) (*User, error) { return &user, nil } +func (us *UserService) UpdatePassword(userID int, password string) error { + hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("update password: %w", err) + } + passwordHash := string(hashedBytes) + _, err = us.DB.Exec(` + UPDATE users + SET password_hash = $2 + WHERE id = $1;`, userID, passwordHash) + if err != nil { + return fmt.Errorf("update password: %w", err) + } + return nil +} + func (us UserService) Authenticate(email, password string) (*User, error) { user := User{ Email: strings.ToLower(email), diff --git a/templates/pwChange.gohtml b/templates/pwChange.gohtml new file mode 100644 index 0000000..d1c9bc4 --- /dev/null +++ b/templates/pwChange.gohtml @@ -0,0 +1,98 @@ + + + {{template "head" .}} + + {{template "header".}} +
+
+
+

+ Reset your password +

+
+ +
+ + +
+ {{if .Token}} + + {{else}} +
+ + +
+ {{end}} + +
+ +
+
+

+ Sign up +

+

+ Sign in +

+
+
+
+
+
+ {{template "footer" .}} + + diff --git a/templates/pwReset.gohtml b/templates/pwReset.gohtml new file mode 100644 index 0000000..04ad39b --- /dev/null +++ b/templates/pwReset.gohtml @@ -0,0 +1,74 @@ + + + {{template "head" .}} + + {{template "header".}} +
+
+
+

+ Forgot your password? +

+

No problem. Enter your email address and we'll send you a link to reset your password.

+
+ +
+ + +
+
+ +
+
+

+ Need an account? + Sign up +

+

+ Remember your password? +

+
+
+
+
+
+ {{template "footer" .}} + + diff --git a/templates/pwResetSent.gohtml b/templates/pwResetSent.gohtml new file mode 100644 index 0000000..b31db26 --- /dev/null +++ b/templates/pwResetSent.gohtml @@ -0,0 +1,18 @@ + + + {{template "head" .}} + + {{template "header".}} +
+
+
+

+ Check your email +

+

An email has been sent to the email address {{.Email}} with instructions to reset your password.

+
+
+
+ {{template "footer" .}} + + diff --git a/templates/signin.gohtml b/templates/signin.gohtml index a999ed0..e635c59 100644 --- a/templates/signin.gohtml +++ b/templates/signin.gohtml @@ -54,7 +54,7 @@ Sign up

- Forgot your password? + Forgot your password?

diff --git a/templates/signup.gohtml b/templates/signup.gohtml index 106f9aa..e5db07f 100644 --- a/templates/signup.gohtml +++ b/templates/signup.gohtml @@ -41,7 +41,7 @@ Sign in

- Forgot your password? + Forgot your password?