Add csrf protection

This commit is contained in:
Lucas Schumacher 2024-08-11 20:23:43 -04:00
parent faf9139d79
commit de681c1ac3
6 changed files with 27 additions and 3 deletions

View File

@ -2,12 +2,14 @@ package controllers
import ( import (
"fmt" "fmt"
"html/template"
"net/http" "net/http"
"strings" "strings"
"git.kealoha.me/lks/lenslocked/models" "git.kealoha.me/lks/lenslocked/models"
"git.kealoha.me/lks/lenslocked/templates" "git.kealoha.me/lks/lenslocked/templates"
"git.kealoha.me/lks/lenslocked/views" "git.kealoha.me/lks/lenslocked/views"
"github.com/gorilla/csrf"
) )
type Users struct { type Users struct {
@ -21,8 +23,10 @@ type Users struct {
func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) { func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) {
var data struct { var data struct {
Email string Email string
CSRFField template.HTML
} }
data.Email = r.FormValue("email") data.Email = r.FormValue("email")
data.CSRFField = csrf.TemplateField(r)
u.Templates.Signup.Execute(w, data) u.Templates.Signup.Execute(w, data)
} }
@ -41,8 +45,10 @@ func (u Users) PostSignup(w http.ResponseWriter, r *http.Request) {
func (u Users) GetSignin(w http.ResponseWriter, r *http.Request) { func (u Users) GetSignin(w http.ResponseWriter, r *http.Request) {
var data struct { var data struct {
Email string Email string
CSRFField template.HTML
} }
data.Email = r.FormValue("email") data.Email = r.FormValue("email")
data.CSRFField = csrf.TemplateField(r)
u.Templates.Signin.Execute(w, data) u.Templates.Signin.Execute(w, data)
} }
func (u Users) PostSignin(w http.ResponseWriter, r *http.Request) { func (u Users) PostSignin(w http.ResponseWriter, r *http.Request) {

2
go.mod
View File

@ -10,6 +10,8 @@ require (
) )
require ( require (
github.com/gorilla/csrf v1.7.2 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgconn v1.14.3 // indirect github.com/jackc/pgconn v1.14.3 // indirect
github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgio v1.0.0 // indirect

4
go.sum
View File

@ -19,6 +19,10 @@ github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI=
github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=

12
main.go
View File

@ -9,11 +9,14 @@ import (
ctrlrs "git.kealoha.me/lks/lenslocked/controllers" ctrlrs "git.kealoha.me/lks/lenslocked/controllers"
"git.kealoha.me/lks/lenslocked/models" "git.kealoha.me/lks/lenslocked/models"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/gorilla/csrf"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
_ "github.com/jackc/pgx/v4/stdlib" _ "github.com/jackc/pgx/v4/stdlib"
) )
const DEBUG bool = true
func notFoundHandler(w http.ResponseWriter, r *http.Request) { func notFoundHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf8") w.Header().Set("Content-Type", "text/html; charset=utf8")
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
@ -33,6 +36,11 @@ func ConnectDB() *sql.DB {
} }
func main() { func main() {
csrfKey := []byte(os.Getenv("LENSLOCKED_CSRF_KEY"))
if len(csrfKey) < 32 {
panic("Error: bad csrf protection key")
}
db := ConnectDB() db := ConnectDB()
defer db.Close() defer db.Close()
@ -40,7 +48,9 @@ func main() {
var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService) var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService)
r := chi.NewRouter() r := chi.NewRouter()
r.Use(middleware.Logger) r.Use(middleware.Logger)
r.Get("/", ctrlrs.StaticController("home.gohtml", "tailwind.gohtml")) r.Get("/", ctrlrs.StaticController("home.gohtml", "tailwind.gohtml"))
r.Get("/contact", ctrlrs.StaticController("contact.gohtml", "tailwind.gohtml")) r.Get("/contact", ctrlrs.StaticController("contact.gohtml", "tailwind.gohtml"))
r.Get("/faq", ctrlrs.FAQ("faq.gohtml", "tailwind.gohtml")) r.Get("/faq", ctrlrs.FAQ("faq.gohtml", "tailwind.gohtml"))
@ -54,5 +64,5 @@ func main() {
r.NotFound(notFoundHandler) r.NotFound(notFoundHandler)
fmt.Println("Starting the server on :3000...") fmt.Println("Starting the server on :3000...")
http.ListenAndServe(":3000", r) http.ListenAndServe(":3000", csrf.Protect(csrfKey, csrf.Secure(!DEBUG))(r))
} }

View File

@ -9,6 +9,7 @@
Welcome back! Welcome back!
</h1> </h1>
<form action="/signin" method="post"> <form action="/signin" method="post">
{{.CSRFField}}
<div class="py-2"> <div class="py-2">
<label for="email" class="text-sm font-semibold text-gray-800"> <label for="email" class="text-sm font-semibold text-gray-800">
Email Address Email Address

View File

@ -10,6 +10,7 @@
Sign Up! Sign Up!
</h1> </h1>
<form action="/signup" method="post"> <form action="/signup" method="post">
{{.CSRFField}}
<div> <div>
<label for="signupEmail" class="text-sm font-semibold text-gray-800">Email Address</label> <label for="signupEmail" class="text-sm font-semibold text-gray-800">Email Address</label>
<input name="email" id="signupEmail" type="email" placeholder="Email address" required autocomplete="email" <input name="email" id="signupEmail" type="email" placeholder="Email address" required autocomplete="email"