Use a custom template function for csrf protection

This commit is contained in:
Lucas Schumacher 2024-08-13 06:58:02 -04:00
parent 8bc58eedbe
commit 4cf50a7d81
5 changed files with 37 additions and 24 deletions

View File

@ -4,7 +4,6 @@ import (
"git.kealoha.me/lks/lenslocked/templates"
"git.kealoha.me/lks/lenslocked/views"
"net/http"
"strings"
)
type Template interface {
@ -14,8 +13,7 @@ type Template interface {
func StaticController(templatePath ...string) http.HandlerFunc {
tpl := views.Must(views.FromFS(templates.FS, templatePath...))
var testWriter strings.Builder
err := tpl.ExecuteWriter(&testWriter, nil)
err := tpl.TestTemplate(nil)
if err != nil {
panic(err)
}
@ -40,8 +38,7 @@ func FAQ(templatePath ...string) http.HandlerFunc {
tpl := views.Must(views.FromFS(templates.FS, templatePath...))
var testWriter strings.Builder
err := tpl.ExecuteWriter(&testWriter, nil)
err := tpl.TestTemplate(nil)
if err != nil {
panic(err)
}

View File

@ -2,14 +2,11 @@ package controllers
import (
"fmt"
"html/template"
"net/http"
"strings"
"git.kealoha.me/lks/lenslocked/models"
"git.kealoha.me/lks/lenslocked/templates"
"git.kealoha.me/lks/lenslocked/views"
"github.com/gorilla/csrf"
)
type Users struct {
@ -23,10 +20,8 @@ type Users struct {
func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) {
var data struct {
Email string
CSRFField template.HTML
}
data.Email = r.FormValue("email")
data.CSRFField = csrf.TemplateField(r)
u.Templates.Signup.Execute(w, r, data)
}
@ -45,10 +40,8 @@ func (u Users) PostSignup(w http.ResponseWriter, r *http.Request) {
func (u Users) GetSignin(w http.ResponseWriter, r *http.Request) {
var data struct {
Email string
CSRFField template.HTML
}
data.Email = r.FormValue("email")
data.CSRFField = csrf.TemplateField(r)
u.Templates.Signin.Execute(w, r, data)
}
func (u Users) PostSignin(w http.ResponseWriter, r *http.Request) {
@ -97,12 +90,11 @@ func Default(user_service *models.UserService, templatePath ...string) Users {
signup_tpl := views.Must(views.FromFS(templates.FS, "signup.gohtml", "tailwind.gohtml"))
signin_tpl := views.Must(views.FromFS(templates.FS, "signin.gohtml", "tailwind.gohtml"))
var testWriter strings.Builder
err := signup_tpl.ExecuteWriter(&testWriter, nil)
err := signup_tpl.TestTemplate(nil)
if err != nil {
panic(err)
}
err = signin_tpl.ExecuteWriter(&testWriter, nil)
err = signin_tpl.TestTemplate(nil)
if err != nil {
panic(err)
}

View File

@ -9,7 +9,7 @@
Welcome back!
</h1>
<form action="/signin" method="post">
{{.CSRFField}}
{{csrfField}}
<div class="py-2">
<label for="email" class="text-sm font-semibold text-gray-800">
Email Address

View File

@ -10,7 +10,7 @@
Sign Up!
</h1>
<form action="/signup" method="post">
{{.CSRFField}}
{{csrfField}}
<div>
<label for="signupEmail" class="text-sm font-semibold text-gray-800">Email Address</label>
<input name="email" id="signupEmail" type="email" placeholder="Email address" required autocomplete="email"

View File

@ -3,11 +3,13 @@ package views
import (
"fmt"
"html/template"
"io"
"io/fs"
"log"
"net/http"
"os"
"strings"
"github.com/gorilla/csrf"
)
type Template struct {
@ -15,16 +17,32 @@ type Template struct {
}
func (t Template) Execute(w http.ResponseWriter, r *http.Request, data interface{}) {
tpl, err := t.htmlTpl.Clone()
if err != nil {
log.Printf("Template Clone Error: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
tpl = tpl.Funcs(template.FuncMap{
"csrfField": func() template.HTML { return csrf.TemplateField(r) },
})
w.Header().Set("Content-Type", "text/html; charset=utf8")
err := t.htmlTpl.Execute(w, data)
err = tpl.Execute(w, data)
if err != nil {
log.Printf("Error executing template: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
}
func (t Template) ExecuteWriter(w io.Writer, data interface{}) error {
return t.htmlTpl.Execute(w, data)
func (t Template) TestTemplate(data interface{}) error {
var testWriter strings.Builder
tpl, err := t.htmlTpl.Clone()
if err != nil {
return err
}
return tpl.Execute(&testWriter, data)
}
func FromFile(pattern ...string) (Template, error) {
@ -32,7 +50,13 @@ func FromFile(pattern ...string) (Template, error) {
return FromFS(fs, pattern...)
}
func FromFS(fs fs.FS, pattern ...string) (Template, error) {
tpl, err := template.ParseFS(fs, pattern...)
tpl := template.New(pattern[0])
tpl = tpl.Funcs(template.FuncMap{
"csrfField": func() template.HTML {
return `<div class="hidden">STUB: PLACEHOLDER</div>`
},
})
tpl, err := tpl.ParseFS(fs, pattern...)
if err != nil {
return Template{}, fmt.Errorf("Error parsing template: %v", err)
}