Compare commits

...

5 Commits

6 changed files with 203 additions and 17 deletions

View File

@@ -14,7 +14,8 @@ type Users struct {
Signup Template Signup Template
Signin Template Signin Template
} }
UserService *models.UserService UserService *models.UserService
SessionService *models.SessionService
} }
func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) { func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) {
@@ -34,7 +35,22 @@ func (u Users) PostSignup(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return return
} }
fmt.Fprintf(w, "User created: %+v", user) session, err := u.SessionService.Create(user.ID)
if err != nil {
fmt.Println(err)
http.Redirect(w, r, "/signin", http.StatusFound)
return
}
cookie := http.Cookie{
Name: "session",
Value: session.Token,
Path: "/",
HttpOnly: true,
}
http.SetCookie(w, &cookie)
http.Redirect(w, r, "/users/me", http.StatusFound)
} }
func (u Users) GetSignin(w http.ResponseWriter, r *http.Request) { func (u Users) GetSignin(w http.ResponseWriter, r *http.Request) {
@@ -58,10 +74,16 @@ func (u Users) PostSignin(w http.ResponseWriter, r *http.Request) {
return return
} }
// Bad cookie session, err := u.SessionService.Create(user.ID)
if err != nil {
fmt.Println(err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
cookie := http.Cookie{ cookie := http.Cookie{
Name: "bad", Name: "session",
Value: user.Email, Value: session.Token,
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
} }
@@ -70,24 +92,52 @@ func (u Users) PostSignin(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "User authenticated: %+v", user) fmt.Fprintf(w, "User authenticated: %+v", user)
} }
func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) { func (u Users) GetSignout(w http.ResponseWriter, r *http.Request) {
email, err := r.Cookie("bad") sessionCookie, err := r.Cookie("session")
if err != nil { if err != nil {
fmt.Fprint(w, "The bad cookie could not be read.") http.Redirect(w, r, "/signin", http.StatusFound)
return return
} }
fmt.Fprintf(w, "Bad cookie: %s\n", email.Value) err = u.SessionService.Delete(sessionCookie.Value)
if err != nil {
fmt.Println(err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
c := http.Cookie{
Name: "session",
MaxAge: -1,
}
http.SetCookie(w, &c)
http.Redirect(w, r, "/signin", http.StatusFound)
} }
func WithTemplates(user_service *models.UserService, signup Template, signin Template) Users { func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) {
seshCookie, err := r.Cookie("session")
if err != nil {
fmt.Println(err)
http.Redirect(w, r, "/signin", http.StatusFound)
return
}
user, err := u.SessionService.User(seshCookie.Value)
if err != nil {
fmt.Println(err)
http.Redirect(w, r, "/signin", http.StatusFound)
return
}
fmt.Fprintf(w, "Current user: %s\n", user.Email)
}
func WithTemplates(user_service *models.UserService, session_service *models.SessionService, signup Template, signin Template) Users {
u := Users{} u := Users{}
u.Templates.Signup = signup u.Templates.Signup = signup
u.Templates.Signin = signin u.Templates.Signin = signin
u.UserService = user_service u.UserService = user_service
u.SessionService = session_service
return u return u
} }
func Default(user_service *models.UserService, templatePath ...string) Users { func Default(user_service *models.UserService, session_service *models.SessionService) Users {
signup_tpl := views.Must(views.FromFS(templates.FS, "signup.gohtml", "tailwind.gohtml")) signup_tpl := views.Must(views.FromFS(templates.FS, "signup.gohtml", "tailwind.gohtml"))
signin_tpl := views.Must(views.FromFS(templates.FS, "signin.gohtml", "tailwind.gohtml")) signin_tpl := views.Must(views.FromFS(templates.FS, "signin.gohtml", "tailwind.gohtml"))
@@ -100,5 +150,5 @@ func Default(user_service *models.UserService, templatePath ...string) Users {
panic(err) panic(err)
} }
return WithTemplates(user_service, signup_tpl, signin_tpl) return WithTemplates(user_service, session_service, signup_tpl, signin_tpl)
} }

View File

@@ -38,14 +38,15 @@ func ConnectDB() *sql.DB {
func main() { func main() {
csrfKey := []byte(os.Getenv("LENSLOCKED_CSRF_KEY")) csrfKey := []byte(os.Getenv("LENSLOCKED_CSRF_KEY"))
if len(csrfKey) < 32 { if len(csrfKey) < 32 {
panic("Error: bad csrf protection key") panic("Error: bad csrf protection key\nPlease set a key with the LENSLOCKED_CSRF_KEY env var.")
} }
db := ConnectDB() db := ConnectDB()
defer db.Close() defer db.Close()
userService := models.UserService{DB: db} userService := models.UserService{DB: db}
var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService) sessionService := models.SessionService{DB: db}
var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService, &sessionService)
r := chi.NewRouter() r := chi.NewRouter()
@@ -59,6 +60,7 @@ func main() {
r.Post("/signup", usersCtrlr.PostSignup) r.Post("/signup", usersCtrlr.PostSignup)
r.Get("/signin", usersCtrlr.GetSignin) r.Get("/signin", usersCtrlr.GetSignin)
r.Post("/signin", usersCtrlr.PostSignin) r.Post("/signin", usersCtrlr.PostSignin)
r.Post("/signout", usersCtrlr.GetSignout)
r.Get("/user", usersCtrlr.CurrentUser) r.Get("/user", usersCtrlr.CurrentUser)

116
models/sessions.go Normal file
View File

@@ -0,0 +1,116 @@
package models
import (
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base64"
"fmt"
)
func RandBytes(n int) ([]byte, error) {
b := make([]byte, n)
nRead, err := rand.Read(b)
if err != nil {
return nil, fmt.Errorf("bytes: %w", err)
}
if nRead < n {
return nil, fmt.Errorf("bytes: didn't read enough random bytes")
}
return b, nil
}
func RandString(n int) (string, error) {
b, err := RandBytes(n)
if err != nil {
return "", fmt.Errorf("string: %w", err)
}
return base64.URLEncoding.EncodeToString(b), nil
}
const SessionTokenBytes = 32
func SessionToken() (string, error) {
return RandString(SessionTokenBytes)
}
func hash(token string) string {
tokenHash := sha256.Sum256([]byte(token))
return base64.StdEncoding.EncodeToString(tokenHash[:])
}
type Session struct {
ID int
UserID int
TokenHash string
// Token is only set when creating a new session. When looking up a session
// this will be left empty, as we only store the hash of a session token
// in our database and we cannot reverse it into a raw token.
Token string
}
type SessionService struct {
DB *sql.DB
}
func (ss *SessionService) Create(userID int) (*Session, error) {
token, err := SessionToken()
if err != nil {
return nil, fmt.Errorf("create: %w", err)
}
session := Session{
UserID: userID,
Token: token,
TokenHash: hash(token),
}
row := ss.DB.QueryRow(`
UPDATE sessions
SET token_hash = $2
WHERE user_id = $1
RETURNING id;
`, session.UserID, session.TokenHash)
err = row.Scan(&session.ID)
if err == sql.ErrNoRows {
row = ss.DB.QueryRow(`
INSERT INTO sessions (user_id, token_hash)
VALUES ($1, $2)
RETURNING id;
`, session.UserID, session.TokenHash)
err = row.Scan(&session.ID)
}
if err != nil {
return nil, fmt.Errorf("create: %w", err)
}
return &session, nil
}
func (ss *SessionService) Delete(token string) error {
tokenHash := hash(token)
_, err := ss.DB.Exec(`DELETE FROM sessions WHERE token_hash = $1;`, tokenHash)
if err != nil {
return fmt.Errorf("delete: %w", err)
}
return nil
}
func (ss *SessionService) User(token string) (*User, error) {
token_hash := hash(token)
var user User
row := ss.DB.QueryRow(`
SELECT (user_id) FROM sessions WHERE token_hash = $1;
`, token_hash)
err := row.Scan(&user.ID)
if err != nil {
return nil, fmt.Errorf("user: %w", err)
}
row = ss.DB.QueryRow(`
SELECT email, password_hash
FROM users WHERE id = $1;
`, user.ID)
err = row.Scan(&user.Email, &user.PasswordHash)
if err != nil {
return nil, fmt.Errorf("user: %w", err)
}
return &user, err
}

5
models/sql/sessions.sql Normal file
View File

@@ -0,0 +1,5 @@
CREATE TABLE sessions (
id SERIAL PRIMARY KEY,
user_id INT UNIQUE,
token_hash TEXT UNIQUE NOT NULL
);

View File

@@ -16,6 +16,10 @@
<a class="text-base font-semibold hover:text-blue-100 pr-8" href="/faq">FAQ</a> <a class="text-base font-semibold hover:text-blue-100 pr-8" href="/faq">FAQ</a>
</div> </div>
<div class="space-x-4"> <div class="space-x-4">
<form action="/signout" method="post" class="inline pr-4">
{{csrfField}}
<button type="submit">Sign out</button>
</form>
<a href="/signin">Sign in</a> <a href="/signin">Sign in</a>
<a href="/signup" clss="px-4 py-2 bg-blue-700 hover:bg-blue-600 rounded">Sign up</a> <a href="/signup" clss="px-4 py-2 bg-blue-700 hover:bg-blue-600 rounded">Sign up</a>
</div> </div>

View File

@@ -1,8 +1,10 @@
package views package views
import ( import (
"bytes"
"fmt" "fmt"
"html/template" "html/template"
"io"
"io/fs" "io/fs"
"log" "log"
"net/http" "net/http"
@@ -29,12 +31,14 @@ func (t Template) Execute(w http.ResponseWriter, r *http.Request, data interface
}) })
w.Header().Set("Content-Type", "text/html; charset=utf8") w.Header().Set("Content-Type", "text/html; charset=utf8")
err = tpl.Execute(w, data) var buf bytes.Buffer
err = tpl.Execute(&buf, data)
if err != nil { if err != nil {
log.Printf("Error executing template: %v", err) log.Printf("Error executing template: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return return
} }
io.Copy(w, &buf)
} }
func (t Template) TestTemplate(data interface{}) error { func (t Template) TestTemplate(data interface{}) error {
var testWriter strings.Builder var testWriter strings.Builder
@@ -42,6 +46,11 @@ func (t Template) TestTemplate(data interface{}) error {
if err != nil { if err != nil {
return err return err
} }
tpl = tpl.Funcs(template.FuncMap{
"csrfField": func() template.HTML {
return `<div class="hidden">STUB: PLACEHOLDER</div>`
},
})
return tpl.Execute(&testWriter, data) return tpl.Execute(&testWriter, data)
} }
@@ -52,8 +61,8 @@ func FromFile(pattern ...string) (Template, error) {
func FromFS(fs fs.FS, pattern ...string) (Template, error) { func FromFS(fs fs.FS, pattern ...string) (Template, error) {
tpl := template.New(pattern[0]) tpl := template.New(pattern[0])
tpl = tpl.Funcs(template.FuncMap{ tpl = tpl.Funcs(template.FuncMap{
"csrfField": func() template.HTML { "csrfField": func() (template.HTML, error) {
return `<div class="hidden">STUB: PLACEHOLDER</div>` return `<div class="hidden">STUB: PLACEHOLDER</div>`, fmt.Errorf("csrfField Not Implimented")
}, },
}) })
tpl, err := tpl.ParseFS(fs, pattern...) tpl, err := tpl.ParseFS(fs, pattern...)