Use a custom template function for csrf protection

This commit is contained in:
2024-08-13 06:58:02 -04:00
parent 8bc58eedbe
commit 4cf50a7d81
5 changed files with 37 additions and 24 deletions

View File

@@ -3,11 +3,13 @@ package views
import (
"fmt"
"html/template"
"io"
"io/fs"
"log"
"net/http"
"os"
"strings"
"github.com/gorilla/csrf"
)
type Template struct {
@@ -15,16 +17,32 @@ type Template struct {
}
func (t Template) Execute(w http.ResponseWriter, r *http.Request, data interface{}) {
tpl, err := t.htmlTpl.Clone()
if err != nil {
log.Printf("Template Clone Error: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
tpl = tpl.Funcs(template.FuncMap{
"csrfField": func() template.HTML { return csrf.TemplateField(r) },
})
w.Header().Set("Content-Type", "text/html; charset=utf8")
err := t.htmlTpl.Execute(w, data)
err = tpl.Execute(w, data)
if err != nil {
log.Printf("Error executing template: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
}
func (t Template) ExecuteWriter(w io.Writer, data interface{}) error {
return t.htmlTpl.Execute(w, data)
func (t Template) TestTemplate(data interface{}) error {
var testWriter strings.Builder
tpl, err := t.htmlTpl.Clone()
if err != nil {
return err
}
return tpl.Execute(&testWriter, data)
}
func FromFile(pattern ...string) (Template, error) {
@@ -32,7 +50,13 @@ func FromFile(pattern ...string) (Template, error) {
return FromFS(fs, pattern...)
}
func FromFS(fs fs.FS, pattern ...string) (Template, error) {
tpl, err := template.ParseFS(fs, pattern...)
tpl := template.New(pattern[0])
tpl = tpl.Funcs(template.FuncMap{
"csrfField": func() template.HTML {
return `<div class="hidden">STUB: PLACEHOLDER</div>`
},
})
tpl, err := tpl.ParseFS(fs, pattern...)
if err != nil {
return Template{}, fmt.Errorf("Error parsing template: %v", err)
}