diff --git a/controllers/static.go b/controllers/static.go index 7a591a5..7b92d6d 100644 --- a/controllers/static.go +++ b/controllers/static.go @@ -4,23 +4,21 @@ import ( "git.kealoha.me/lks/lenslocked/templates" "git.kealoha.me/lks/lenslocked/views" "net/http" - "strings" ) type Template interface { - Execute(w http.ResponseWriter, data interface{}) + Execute(w http.ResponseWriter, r *http.Request, data interface{}) } -func StaticTemplate(templatePath ...string) http.HandlerFunc { +func StaticController(templatePath ...string) http.HandlerFunc { tpl := views.Must(views.FromFS(templates.FS, templatePath...)) - var testWriter strings.Builder - err := tpl.ExecuteWriter(&testWriter, nil) + err := tpl.TestTemplate(nil) if err != nil { panic(err) } - return func(w http.ResponseWriter, r *http.Request) { tpl.Execute(w, nil) } + return func(w http.ResponseWriter, r *http.Request) { tpl.Execute(w, r, nil) } } func FAQ(templatePath ...string) http.HandlerFunc { @@ -40,13 +38,12 @@ func FAQ(templatePath ...string) http.HandlerFunc { tpl := views.Must(views.FromFS(templates.FS, templatePath...)) - var testWriter strings.Builder - err := tpl.ExecuteWriter(&testWriter, nil) + err := tpl.TestTemplate(nil) if err != nil { panic(err) } return func(w http.ResponseWriter, r *http.Request) { - tpl.Execute(w, questions) + tpl.Execute(w, r, questions) } } diff --git a/controllers/users.go b/controllers/users.go index 44facc0..882bcd6 100644 --- a/controllers/users.go +++ b/controllers/users.go @@ -3,40 +3,102 @@ package controllers import ( "fmt" "net/http" - "strings" + "git.kealoha.me/lks/lenslocked/models" "git.kealoha.me/lks/lenslocked/templates" "git.kealoha.me/lks/lenslocked/views" ) type Users struct { Templates struct { - New Template + Signup Template + Signin Template } + UserService *models.UserService } -func (u Users) New(w http.ResponseWriter, r *http.Request) { +func (u Users) GetSignup(w http.ResponseWriter, r *http.Request) { var data struct { Email string } data.Email = r.FormValue("email") - u.Templates.New.Execute(w, data) -} -func (u Users) Create(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "TODO! ", r.FormValue("email")) - + u.Templates.Signup.Execute(w, r, data) } -func FromStaticTemplate(templatePath ...string) Users { - tpl := views.Must(views.FromFS(templates.FS, templatePath...)) +func (u Users) PostSignup(w http.ResponseWriter, r *http.Request) { + email := r.FormValue("email") + password := r.FormValue("password") + user, err := u.UserService.Create(email, password) + if err != nil { + fmt.Println(err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + fmt.Fprintf(w, "User created: %+v", user) +} - var testWriter strings.Builder - err := tpl.ExecuteWriter(&testWriter, nil) +func (u Users) GetSignin(w http.ResponseWriter, r *http.Request) { + var data struct { + Email string + } + data.Email = r.FormValue("email") + u.Templates.Signin.Execute(w, r, data) +} +func (u Users) PostSignin(w http.ResponseWriter, r *http.Request) { + var data struct { + Email string + Password string + } + data.Email = r.FormValue("email") + data.Password = r.FormValue("password") + user, err := u.UserService.Authenticate(data.Email, data.Password) + if err != nil { + fmt.Println(err) + http.Error(w, "Something went wrong.", http.StatusInternalServerError) + return + } + + // Bad cookie + cookie := http.Cookie{ + Name: "bad", + Value: user.Email, + Path: "/", + HttpOnly: true, + } + http.SetCookie(w, &cookie) + + fmt.Fprintf(w, "User authenticated: %+v", user) +} + +func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) { + email, err := r.Cookie("bad") + if err != nil { + fmt.Fprint(w, "The bad cookie could not be read.") + return + } + fmt.Fprintf(w, "Bad cookie: %s\n", email.Value) +} + +func WithTemplates(user_service *models.UserService, signup Template, signin Template) Users { + u := Users{} + u.Templates.Signup = signup + u.Templates.Signin = signin + u.UserService = user_service + return u +} + +func Default(user_service *models.UserService, templatePath ...string) Users { + signup_tpl := views.Must(views.FromFS(templates.FS, "signup.gohtml", "tailwind.gohtml")) + signin_tpl := views.Must(views.FromFS(templates.FS, "signin.gohtml", "tailwind.gohtml")) + + err := signup_tpl.TestTemplate(nil) + if err != nil { + panic(err) + } + err = signin_tpl.TestTemplate(nil) if err != nil { panic(err) } - u := Users{} - u.Templates.New = tpl - return u + return WithTemplates(user_service, signup_tpl, signin_tpl) } diff --git a/go.mod b/go.mod index 231a4ff..f914105 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,12 @@ go 1.22.5 require ( github.com/go-chi/chi/v5 v5.1.0 github.com/jackc/pgx/v4 v4.18.3 - golang.org/x/crypto v0.20.0 + golang.org/x/crypto v0.26.0 ) require ( + github.com/gorilla/csrf v1.7.2 + github.com/gorilla/securecookie v1.1.2 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 // indirect github.com/jackc/pgio v1.0.0 // indirect @@ -16,5 +18,5 @@ require ( github.com/jackc/pgproto3/v2 v2.3.3 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgtype v1.14.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/text v0.17.0 // indirect ) diff --git a/go.sum b/go.sum index 81b718f..53b29bf 100644 --- a/go.sum +++ b/go.sum @@ -16,7 +16,13 @@ github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= +github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= @@ -126,8 +132,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.20.0 h1:jmAMJJZXr5KiCw05dfYK9QnqaqKLYXijU23lsEdcQqg= -golang.org/x/crypto v0.20.0/go.mod h1:Xwo95rrVNIoSMx9wa1JroENMToLWn3RNVrTBpLHgZPQ= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= @@ -156,8 +162,8 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= diff --git a/main.go b/main.go index 347616c..d665c32 100644 --- a/main.go +++ b/main.go @@ -1,30 +1,68 @@ package main import ( + "database/sql" "fmt" "net/http" + "os" ctrlrs "git.kealoha.me/lks/lenslocked/controllers" + "git.kealoha.me/lks/lenslocked/models" "github.com/go-chi/chi/v5" + "github.com/gorilla/csrf" + "github.com/go-chi/chi/v5/middleware" + _ "github.com/jackc/pgx/v4/stdlib" ) +const DEBUG bool = true + func notFoundHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf8") w.WriteHeader(http.StatusNotFound) fmt.Fprint(w, "404 page not found") } +func ConnectDB() *sql.DB { + db, err := sql.Open("pgx", os.Getenv("LENSLOCKED_DB_STRING")) + if err != nil { + panic(fmt.Sprint("Error connecting to database: %w", err)) + } + err = db.Ping() + if err != nil { + panic(fmt.Sprint("Error connecting to database: %w", err)) + } + return db +} + func main() { - var usersCtrlr ctrlrs.Users = ctrlrs.FromStaticTemplate("signup.gohtml", "tailwind.gohtml") + csrfKey := []byte(os.Getenv("LENSLOCKED_CSRF_KEY")) + if len(csrfKey) < 32 { + panic("Error: bad csrf protection key") + } + + db := ConnectDB() + defer db.Close() + + userService := models.UserService{DB: db} + var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService) + r := chi.NewRouter() + r.Use(middleware.Logger) - r.Get("/", ctrlrs.StaticTemplate("home.gohtml", "tailwind.gohtml")) - r.Get("/contact", ctrlrs.StaticTemplate("contact.gohtml", "tailwind.gohtml")) + + r.Get("/", ctrlrs.StaticController("home.gohtml", "tailwind.gohtml")) + r.Get("/contact", ctrlrs.StaticController("contact.gohtml", "tailwind.gohtml")) r.Get("/faq", ctrlrs.FAQ("faq.gohtml", "tailwind.gohtml")) - r.Get("/signup", usersCtrlr.New) - r.Post("/signup", usersCtrlr.Create) + + r.Get("/signup", usersCtrlr.GetSignup) + r.Post("/signup", usersCtrlr.PostSignup) + r.Get("/signin", usersCtrlr.GetSignin) + r.Post("/signin", usersCtrlr.PostSignin) + + r.Get("/user", usersCtrlr.CurrentUser) + r.NotFound(notFoundHandler) fmt.Println("Starting the server on :3000...") - http.ListenAndServe(":3000", r) + http.ListenAndServe(":3000", csrf.Protect(csrfKey, csrf.Secure(!DEBUG))(r)) } diff --git a/models/sql/users.sql b/models/sql/users.sql new file mode 100644 index 0000000..edcb85f --- /dev/null +++ b/models/sql/users.sql @@ -0,0 +1,5 @@ +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + email TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL +); diff --git a/models/user.go b/models/user.go new file mode 100644 index 0000000..286dd91 --- /dev/null +++ b/models/user.go @@ -0,0 +1,64 @@ +package models + +import ( + "database/sql" + "fmt" + "strings" + + "golang.org/x/crypto/bcrypt" +) + +type User struct { + ID int + Email string + PasswordHash string +} + +type UserService struct { + DB *sql.DB +} + +func (us *UserService) Create(email, password string) (*User, error) { + email = strings.ToLower(email) + + hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, fmt.Errorf("create user: %w", err) + } + passwordHash := string(hashedBytes) + + user := User{ + Email: email, + PasswordHash: passwordHash, + } + row := us.DB.QueryRow(` + INSERT INTO users (email, password_hash) + VALUES ($1, $2) RETURNING id + `, email, passwordHash) + err = row.Scan(&user.ID) + if err != nil { + return nil, fmt.Errorf("create user: %w", err) + } + return &user, nil +} + +func (us UserService) Authenticate(email, password string) (*User, error) { + user := User{ + Email: strings.ToLower(email), + } + + row := us.DB.QueryRow(` + SELECT id, password_hash + FROM users WHERE email=$1 + `, email) + err := row.Scan(&user.ID, &user.PasswordHash) + if err != nil { + return nil, fmt.Errorf("authenticate: %w", err) + } + + err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) + if err != nil { + return nil, fmt.Errorf("authenticate: %w", err) + } + return &user, nil +} diff --git a/templates/signin.gohtml b/templates/signin.gohtml new file mode 100644 index 0000000..a999ed0 --- /dev/null +++ b/templates/signin.gohtml @@ -0,0 +1,65 @@ + + + {{template "head" .}} + + {{template "header".}} +
+
+

+ Welcome back! +

+
+ {{csrfField}} +
+ + +
+
+ + +
+
+ +
+
+

+ Need an account? + Sign up +

+

+ Forgot your password? +

+
+
+
+
+ {{template "footer" .}} + + diff --git a/templates/signup.gohtml b/templates/signup.gohtml index 894d22a..106f9aa 100644 --- a/templates/signup.gohtml +++ b/templates/signup.gohtml @@ -10,6 +10,7 @@ Sign Up!
+ {{csrfField}}
FAQ
- Sign in + Sign in Sign up
diff --git a/views/template.go b/views/template.go index ad2e4db..40c4fc5 100644 --- a/views/template.go +++ b/views/template.go @@ -3,40 +3,60 @@ package views import ( "fmt" "html/template" - "io" "io/fs" "log" "net/http" + "os" + "strings" + + "github.com/gorilla/csrf" ) type Template struct { htmlTpl *template.Template } -func (t Template) Execute(w http.ResponseWriter, data interface{}) { +func (t Template) Execute(w http.ResponseWriter, r *http.Request, data interface{}) { + tpl, err := t.htmlTpl.Clone() + if err != nil { + log.Printf("Template Clone Error: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + tpl = tpl.Funcs(template.FuncMap{ + "csrfField": func() template.HTML { return csrf.TemplateField(r) }, + }) + w.Header().Set("Content-Type", "text/html; charset=utf8") - err := t.htmlTpl.Execute(w, data) + err = tpl.Execute(w, data) if err != nil { log.Printf("Error executing template: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } } -func (t Template) ExecuteWriter(w io.Writer, data interface{}) error { - return t.htmlTpl.Execute(w, data) +func (t Template) TestTemplate(data interface{}) error { + var testWriter strings.Builder + tpl, err := t.htmlTpl.Clone() + if err != nil { + return err + } + return tpl.Execute(&testWriter, data) } -func FromFile(filepath string) (Template, error) { - tpl, err := template.ParseFiles(filepath) - if err != nil { - return Template{}, fmt.Errorf("Error parsing template: %v", err) - } - return Template{ - htmlTpl: tpl, - }, nil +func FromFile(pattern ...string) (Template, error) { + fs := os.DirFS(".") + return FromFS(fs, pattern...) } func FromFS(fs fs.FS, pattern ...string) (Template, error) { - tpl, err := template.ParseFS(fs, pattern...) + tpl := template.New(pattern[0]) + tpl = tpl.Funcs(template.FuncMap{ + "csrfField": func() template.HTML { + return `` + }, + }) + tpl, err := tpl.ParseFS(fs, pattern...) if err != nil { return Template{}, fmt.Errorf("Error parsing template: %v", err) }