From e32aa9ca6c4a21df28be062ef069cdce3baa3aff Mon Sep 17 00:00:00 2001 From: Lucas Schumacher Date: Wed, 28 Aug 2024 22:54:58 -0400 Subject: [PATCH] Use middleware for user session --- context/users.go | 49 ++++++++++++++++++++++++++++++++++++++++++++ controllers/users.go | 15 +++++--------- main.go | 6 +++++- 3 files changed, 59 insertions(+), 11 deletions(-) create mode 100644 context/users.go diff --git a/context/users.go b/context/users.go new file mode 100644 index 0000000..5fd85a5 --- /dev/null +++ b/context/users.go @@ -0,0 +1,49 @@ +package userctx + +import ( + "context" + "net/http" + + "git.kealoha.me/lks/lenslocked/models" +) + +type key string + +const userKey key = "User" + +func WithUser(ctx context.Context, user *models.User) context.Context { + return context.WithValue(ctx, userKey, user) +} + +func User(ctx context.Context) *models.User { + val := ctx.Value(userKey) + user, ok := val.(*models.User) + if !ok { + return nil + } + return user +} + +type UserMiddleware struct { + SS *models.SessionService +} + +func (umw UserMiddleware) SetUser(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seshCookie, err := r.Cookie("session") + if err != nil { + next.ServeHTTP(w, r) + return + } + user, err := umw.SS.User(seshCookie.Value) + if err != nil { + next.ServeHTTP(w, r) + return + } + + ctx := r.Context() + ctx = WithUser(ctx, user) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) +} diff --git a/controllers/users.go b/controllers/users.go index a4219f5..7e9314b 100644 --- a/controllers/users.go +++ b/controllers/users.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" + userctx "git.kealoha.me/lks/lenslocked/context" "git.kealoha.me/lks/lenslocked/models" "git.kealoha.me/lks/lenslocked/templates" "git.kealoha.me/lks/lenslocked/views" @@ -89,7 +90,8 @@ func (u Users) PostSignin(w http.ResponseWriter, r *http.Request) { } http.SetCookie(w, &cookie) - http.Redirect(w, r, "/user", http.StatusFound) + fmt.Fprintf(w, "Current user: %s\n", user.Email) + //http.Redirect(w, r, "/user", http.StatusFound) } func (u Users) GetSignout(w http.ResponseWriter, r *http.Request) { @@ -113,15 +115,8 @@ func (u Users) GetSignout(w http.ResponseWriter, r *http.Request) { } func (u Users) CurrentUser(w http.ResponseWriter, r *http.Request) { - seshCookie, err := r.Cookie("session") - if err != nil { - fmt.Println(err) - http.Redirect(w, r, "/signin", http.StatusFound) - return - } - user, err := u.SessionService.User(seshCookie.Value) - if err != nil { - fmt.Println(err) + user := userctx.User(r.Context()) + if user == nil { http.Redirect(w, r, "/signin", http.StatusFound) return } diff --git a/main.go b/main.go index 8589182..eb0e876 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "net/http" "os" + userctx "git.kealoha.me/lks/lenslocked/context" ctrlrs "git.kealoha.me/lks/lenslocked/controllers" "git.kealoha.me/lks/lenslocked/migrations" "git.kealoha.me/lks/lenslocked/models" @@ -85,6 +86,9 @@ func main() { r.Get("/user", usersCtrlr.CurrentUser) r.NotFound(notFoundHandler) + + umw := userctx.UserMiddleware{SS: &sessionService} + fmt.Println("Starting the server on :3000...") - http.ListenAndServe(":3000", csrf.Protect(csrfKey, csrf.Secure(!DEBUG))(r)) + http.ListenAndServe(":3000", csrf.Protect(csrfKey, csrf.Secure(!DEBUG))(umw.SetUser(r))) }