package models import ( "crypto/rand" "crypto/sha256" "database/sql" "encoding/base64" "fmt" ) func RandBytes(n int) ([]byte, error) { b := make([]byte, n) nRead, err := rand.Read(b) if err != nil { return nil, fmt.Errorf("bytes: %w", err) } if nRead < n { return nil, fmt.Errorf("bytes: didn't read enough random bytes") } return b, nil } func RandString(n int) (string, error) { b, err := RandBytes(n) if err != nil { return "", fmt.Errorf("string: %w", err) } return base64.URLEncoding.EncodeToString(b), nil } const SessionTokenBytes = 32 func SessionToken() (string, error) { return RandString(SessionTokenBytes) } func hash(token string) string { tokenHash := sha256.Sum256([]byte(token)) return base64.StdEncoding.EncodeToString(tokenHash[:]) } type Session struct { ID int UserID int TokenHash string // Token is only set when creating a new session. When looking up a session // this will be left empty, as we only store the hash of a session token // in our database and we cannot reverse it into a raw token. Token string } type SessionService struct { DB *sql.DB } func (ss *SessionService) Create(userID int) (*Session, error) { token, err := SessionToken() if err != nil { return nil, fmt.Errorf("create: %w", err) } session := Session{ UserID: userID, Token: token, TokenHash: hash(token), } row := ss.DB.QueryRow(` UPDATE sessions SET token_hash = $2 WHERE user_id = $1 RETURNING id; `, session.UserID, session.TokenHash) err = row.Scan(&session.ID) if err == sql.ErrNoRows { row = ss.DB.QueryRow(` INSERT INTO sessions (user_id, token_hash) VALUES ($1, $2) RETURNING id; `, session.UserID, session.TokenHash) err = row.Scan(&session.ID) } if err != nil { return nil, fmt.Errorf("create: %w", err) } return &session, nil } func (ss *SessionService) Delete(token string) error { tokenHash := hash(token) _, err := ss.DB.Exec(`DELETE FROM sessions WHERE token_hash = $1;`, tokenHash) if err != nil { return fmt.Errorf("delete: %w", err) } return nil } func (ss *SessionService) User(token string) (*User, error) { token_hash := hash(token) var user User row := ss.DB.QueryRow(` SELECT (user_id) FROM sessions WHERE token_hash = $1; `, token_hash) err := row.Scan(&user.ID) if err != nil { return nil, fmt.Errorf("user: %w", err) } row = ss.DB.QueryRow(` SELECT email, password_hash FROM users WHERE id = $1; `, user.ID) err = row.Scan(&user.Email, &user.PasswordHash) if err != nil { return nil, fmt.Errorf("user: %w", err) } return &user, err }