From dfde1b838199e22b012c3d551a8b0a281ba91e30 Mon Sep 17 00:00:00 2001 From: Lucas Schumacher Date: Wed, 21 Aug 2024 23:34:10 -0400 Subject: [PATCH] Add user sessions --- controllers/users.go | 52 ++++++++++++++----- main.go | 3 +- models/sessions.go | 107 ++++++++++++++++++++++++++++++++++++++++ models/sql/sessions.sql | 5 ++ 4 files changed, 155 insertions(+), 12 deletions(-) create mode 100644 models/sessions.go create mode 100644 models/sql/sessions.sql diff --git a/controllers/users.go b/controllers/users.go index 882bcd6..8b6ef03 100644 --- a/controllers/users.go +++ b/controllers/users.go @@ -14,7 +14,8 @@ type Users struct { Signup Template Signin Template } - UserService *models.UserService + UserService *models.UserService + SessionService *models.SessionService } func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) { @@ -34,7 +35,22 @@ func (u Users) PostSignup(w http.ResponseWriter, r *http.Request) { http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - fmt.Fprintf(w, "User created: %+v", user) + session, err := u.SessionService.Create(user.ID) + if err != nil { + fmt.Println(err) + http.Redirect(w, r, "/signin", http.StatusFound) + return + } + + cookie := http.Cookie{ + Name: "session", + Value: session.Token, + Path: "/", + HttpOnly: true, + } + http.SetCookie(w, &cookie) + + http.Redirect(w, r, "/users/me", http.StatusFound) } func (u Users) GetSignin(w http.ResponseWriter, r *http.Request) { @@ -58,10 +74,16 @@ func (u Users) PostSignin(w http.ResponseWriter, r *http.Request) { return } - // Bad cookie + session, err := u.SessionService.Create(user.ID) + if err != nil { + fmt.Println(err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + cookie := http.Cookie{ - Name: "bad", - Value: user.Email, + Name: "session", + Value: session.Token, Path: "/", HttpOnly: true, } @@ -71,23 +93,31 @@ func (u Users) PostSignin(w http.ResponseWriter, r *http.Request) { } func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) { - email, err := r.Cookie("bad") + seshCookie, err := r.Cookie("session") if err != nil { - fmt.Fprint(w, "The bad cookie could not be read.") + fmt.Println(err) + http.Redirect(w, r, "/signin", http.StatusFound) return } - fmt.Fprintf(w, "Bad cookie: %s\n", email.Value) + user, err := u.SessionService.User(seshCookie.Value) + if err != nil { + fmt.Println(err) + http.Redirect(w, r, "/signin", http.StatusFound) + return + } + fmt.Fprintf(w, "Current user: %s\n", user.Email) } -func WithTemplates(user_service *models.UserService, signup Template, signin Template) Users { +func WithTemplates(user_service *models.UserService, session_service *models.SessionService, signup Template, signin Template) Users { u := Users{} u.Templates.Signup = signup u.Templates.Signin = signin u.UserService = user_service + u.SessionService = session_service return u } -func Default(user_service *models.UserService, templatePath ...string) Users { +func Default(user_service *models.UserService, session_service *models.SessionService) Users { signup_tpl := views.Must(views.FromFS(templates.FS, "signup.gohtml", "tailwind.gohtml")) signin_tpl := views.Must(views.FromFS(templates.FS, "signin.gohtml", "tailwind.gohtml")) @@ -100,5 +130,5 @@ func Default(user_service *models.UserService, templatePath ...string) Users { panic(err) } - return WithTemplates(user_service, signup_tpl, signin_tpl) + return WithTemplates(user_service, session_service, signup_tpl, signin_tpl) } diff --git a/main.go b/main.go index 69e9193..3d0df4b 100644 --- a/main.go +++ b/main.go @@ -45,7 +45,8 @@ func main() { defer db.Close() userService := models.UserService{DB: db} - var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService) + sessionService := models.SessionService{DB: db} + var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService, &sessionService) r := chi.NewRouter() diff --git a/models/sessions.go b/models/sessions.go new file mode 100644 index 0000000..604e1b7 --- /dev/null +++ b/models/sessions.go @@ -0,0 +1,107 @@ +package models + +import ( + "crypto/rand" + "crypto/sha256" + "database/sql" + "encoding/base64" + "fmt" +) + +func RandBytes(n int) ([]byte, error) { + b := make([]byte, n) + nRead, err := rand.Read(b) + if err != nil { + return nil, fmt.Errorf("bytes: %w", err) + } + if nRead < n { + return nil, fmt.Errorf("bytes: didn't read enough random bytes") + } + return b, nil +} +func RandString(n int) (string, error) { + b, err := RandBytes(n) + if err != nil { + return "", fmt.Errorf("string: %w", err) + } + return base64.URLEncoding.EncodeToString(b), nil +} + +const SessionTokenBytes = 32 + +func SessionToken() (string, error) { + return RandString(SessionTokenBytes) +} +func hash(token string) string { + tokenHash := sha256.Sum256([]byte(token)) + return base64.StdEncoding.EncodeToString(tokenHash[:]) +} + +type Session struct { + ID int + UserID int + TokenHash string + // Token is only set when creating a new session. When looking up a session + // this will be left empty, as we only store the hash of a session token + // in our database and we cannot reverse it into a raw token. + Token string +} + +type SessionService struct { + DB *sql.DB +} + +func (ss *SessionService) Create(userID int) (*Session, error) { + token, err := SessionToken() + if err != nil { + return nil, fmt.Errorf("create: %w", err) + } + session := Session{ + UserID: userID, + Token: token, + TokenHash: hash(token), + } + + row := ss.DB.QueryRow(` + UPDATE sessions + SET token_hash = $2 + WHERE user_id = $1 + RETURNING id; + `, session.UserID, session.TokenHash) + err = row.Scan(&session.ID) + if err == sql.ErrNoRows { + row = ss.DB.QueryRow(` + INSERT INTO sessions (user_id, token_hash) + VALUES ($1, $2) + RETURNING id; + `, session.UserID, session.TokenHash) + err = row.Scan(&session.ID) + } + + if err != nil { + return nil, fmt.Errorf("create: %w", err) + } + + return &session, nil +} + +func (ss *SessionService) User(token string) (*User, error) { + token_hash := hash(token) + var user User + row := ss.DB.QueryRow(` + SELECT (user_id) FROM sessions WHERE token_hash = $1; + `, token_hash) + err := row.Scan(&user.ID) + if err != nil { + return nil, fmt.Errorf("user: %w", err) + } + row = ss.DB.QueryRow(` + SELECT email, password_hash + FROM users WHERE id = $1; + `, user.ID) + err = row.Scan(&user.Email, &user.PasswordHash) + if err != nil { + return nil, fmt.Errorf("user: %w", err) + } + return &user, err +} diff --git a/models/sql/sessions.sql b/models/sql/sessions.sql new file mode 100644 index 0000000..aa7d099 --- /dev/null +++ b/models/sql/sessions.sql @@ -0,0 +1,5 @@ +CREATE TABLE sessions ( + id SERIAL PRIMARY KEY, + user_id INT UNIQUE, + token_hash TEXT UNIQUE NOT NULL +);