Add user sessions
This commit is contained in:
parent
87cae430a3
commit
dfde1b8381
@ -14,7 +14,8 @@ type Users struct {
|
|||||||
Signup Template
|
Signup Template
|
||||||
Signin Template
|
Signin Template
|
||||||
}
|
}
|
||||||
UserService *models.UserService
|
UserService *models.UserService
|
||||||
|
SessionService *models.SessionService
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) {
|
func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -34,7 +35,22 @@ func (u Users) PostSignup(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmt.Fprintf(w, "User created: %+v", user)
|
session, err := u.SessionService.Create(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
http.Redirect(w, r, "/signin", http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie := http.Cookie{
|
||||||
|
Name: "session",
|
||||||
|
Value: session.Token,
|
||||||
|
Path: "/",
|
||||||
|
HttpOnly: true,
|
||||||
|
}
|
||||||
|
http.SetCookie(w, &cookie)
|
||||||
|
|
||||||
|
http.Redirect(w, r, "/users/me", http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u Users) GetSignin(w http.ResponseWriter, r *http.Request) {
|
func (u Users) GetSignin(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -58,10 +74,16 @@ func (u Users) PostSignin(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bad cookie
|
session, err := u.SessionService.Create(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
cookie := http.Cookie{
|
cookie := http.Cookie{
|
||||||
Name: "bad",
|
Name: "session",
|
||||||
Value: user.Email,
|
Value: session.Token,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
}
|
}
|
||||||
@ -71,23 +93,31 @@ func (u Users) PostSignin(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) {
|
func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) {
|
||||||
email, err := r.Cookie("bad")
|
seshCookie, err := r.Cookie("session")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprint(w, "The bad cookie could not be read.")
|
fmt.Println(err)
|
||||||
|
http.Redirect(w, r, "/signin", http.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmt.Fprintf(w, "Bad cookie: %s\n", email.Value)
|
user, err := u.SessionService.User(seshCookie.Value)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
http.Redirect(w, r, "/signin", http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, "Current user: %s\n", user.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithTemplates(user_service *models.UserService, signup Template, signin Template) Users {
|
func WithTemplates(user_service *models.UserService, session_service *models.SessionService, signup Template, signin Template) Users {
|
||||||
u := Users{}
|
u := Users{}
|
||||||
u.Templates.Signup = signup
|
u.Templates.Signup = signup
|
||||||
u.Templates.Signin = signin
|
u.Templates.Signin = signin
|
||||||
u.UserService = user_service
|
u.UserService = user_service
|
||||||
|
u.SessionService = session_service
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func Default(user_service *models.UserService, templatePath ...string) Users {
|
func Default(user_service *models.UserService, session_service *models.SessionService) Users {
|
||||||
signup_tpl := views.Must(views.FromFS(templates.FS, "signup.gohtml", "tailwind.gohtml"))
|
signup_tpl := views.Must(views.FromFS(templates.FS, "signup.gohtml", "tailwind.gohtml"))
|
||||||
signin_tpl := views.Must(views.FromFS(templates.FS, "signin.gohtml", "tailwind.gohtml"))
|
signin_tpl := views.Must(views.FromFS(templates.FS, "signin.gohtml", "tailwind.gohtml"))
|
||||||
|
|
||||||
@ -100,5 +130,5 @@ func Default(user_service *models.UserService, templatePath ...string) Users {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return WithTemplates(user_service, signup_tpl, signin_tpl)
|
return WithTemplates(user_service, session_service, signup_tpl, signin_tpl)
|
||||||
}
|
}
|
||||||
|
|||||||
3
main.go
3
main.go
@ -45,7 +45,8 @@ func main() {
|
|||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
userService := models.UserService{DB: db}
|
userService := models.UserService{DB: db}
|
||||||
var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService)
|
sessionService := models.SessionService{DB: db}
|
||||||
|
var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService, &sessionService)
|
||||||
|
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
|
||||||
|
|||||||
107
models/sessions.go
Normal file
107
models/sessions.go
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RandBytes(n int) ([]byte, error) {
|
||||||
|
b := make([]byte, n)
|
||||||
|
nRead, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("bytes: %w", err)
|
||||||
|
}
|
||||||
|
if nRead < n {
|
||||||
|
return nil, fmt.Errorf("bytes: didn't read enough random bytes")
|
||||||
|
}
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
func RandString(n int) (string, error) {
|
||||||
|
b, err := RandBytes(n)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("string: %w", err)
|
||||||
|
}
|
||||||
|
return base64.URLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const SessionTokenBytes = 32
|
||||||
|
|
||||||
|
func SessionToken() (string, error) {
|
||||||
|
return RandString(SessionTokenBytes)
|
||||||
|
}
|
||||||
|
func hash(token string) string {
|
||||||
|
tokenHash := sha256.Sum256([]byte(token))
|
||||||
|
return base64.StdEncoding.EncodeToString(tokenHash[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
ID int
|
||||||
|
UserID int
|
||||||
|
TokenHash string
|
||||||
|
// Token is only set when creating a new session. When looking up a session
|
||||||
|
// this will be left empty, as we only store the hash of a session token
|
||||||
|
// in our database and we cannot reverse it into a raw token.
|
||||||
|
Token string
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionService struct {
|
||||||
|
DB *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss *SessionService) Create(userID int) (*Session, error) {
|
||||||
|
token, err := SessionToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create: %w", err)
|
||||||
|
}
|
||||||
|
session := Session{
|
||||||
|
UserID: userID,
|
||||||
|
Token: token,
|
||||||
|
TokenHash: hash(token),
|
||||||
|
}
|
||||||
|
|
||||||
|
row := ss.DB.QueryRow(`
|
||||||
|
UPDATE sessions
|
||||||
|
SET token_hash = $2
|
||||||
|
WHERE user_id = $1
|
||||||
|
RETURNING id;
|
||||||
|
`, session.UserID, session.TokenHash)
|
||||||
|
err = row.Scan(&session.ID)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
row = ss.DB.QueryRow(`
|
||||||
|
INSERT INTO sessions (user_id, token_hash)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
RETURNING id;
|
||||||
|
`, session.UserID, session.TokenHash)
|
||||||
|
err = row.Scan(&session.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss *SessionService) User(token string) (*User, error) {
|
||||||
|
token_hash := hash(token)
|
||||||
|
var user User
|
||||||
|
row := ss.DB.QueryRow(`
|
||||||
|
SELECT (user_id) FROM sessions WHERE token_hash = $1;
|
||||||
|
`, token_hash)
|
||||||
|
err := row.Scan(&user.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("user: %w", err)
|
||||||
|
}
|
||||||
|
row = ss.DB.QueryRow(`
|
||||||
|
SELECT email, password_hash
|
||||||
|
FROM users WHERE id = $1;
|
||||||
|
`, user.ID)
|
||||||
|
err = row.Scan(&user.Email, &user.PasswordHash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("user: %w", err)
|
||||||
|
}
|
||||||
|
return &user, err
|
||||||
|
}
|
||||||
5
models/sql/sessions.sql
Normal file
5
models/sql/sessions.sql
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
CREATE TABLE sessions (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
user_id INT UNIQUE,
|
||||||
|
token_hash TEXT UNIQUE NOT NULL
|
||||||
|
);
|
||||||
Loading…
x
Reference in New Issue
Block a user