diff --git a/context/users.go b/context/users.go index 5fd85a5..89233e1 100644 --- a/context/users.go +++ b/context/users.go @@ -47,3 +47,25 @@ func (umw UserMiddleware) SetUser(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +func (umw UserMiddleware) RequireUserfn(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := User(r.Context()) + if user == nil { + http.Redirect(w, r, "/signin", http.StatusFound) + return + } + next(w, r) + }) +} + +func (umw UserMiddleware) RequireUser(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := User(r.Context()) + if user == nil { + http.Redirect(w, r, "/signin", http.StatusFound) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/main.go b/main.go index eb0e876..9a76d53 100644 --- a/main.go +++ b/main.go @@ -69,9 +69,13 @@ func main() { sessionService := models.SessionService{DB: db} var usersCtrlr ctrlrs.Users = ctrlrs.Default(&userService, &sessionService) + umw := userctx.UserMiddleware{SS: &sessionService} + r := chi.NewRouter() r.Use(middleware.Logger) + r.Use(csrf.Protect(csrfKey, csrf.Secure(!DEBUG))) + r.Use(umw.SetUser) r.Get("/", ctrlrs.StaticController("home.gohtml", "tailwind.gohtml")) r.Get("/contact", ctrlrs.StaticController("contact.gohtml", "tailwind.gohtml")) @@ -83,12 +87,10 @@ func main() { r.Post("/signin", usersCtrlr.PostSignin) r.Post("/signout", usersCtrlr.GetSignout) - r.Get("/user", usersCtrlr.CurrentUser) - + //r.Get("/user", usersCtrlr.CurrentUser) + r.Get("/user", umw.RequireUserfn(usersCtrlr.CurrentUser)) r.NotFound(notFoundHandler) - umw := userctx.UserMiddleware{SS: &sessionService} - fmt.Println("Starting the server on :3000...") - http.ListenAndServe(":3000", csrf.Protect(csrfKey, csrf.Secure(!DEBUG))(umw.SetUser(r))) + http.ListenAndServe(":3000", r) }