diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 7e2630d..57c094a 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,12 +1,16 @@ package auth import ( + "errors" + "log" + "net/http" + "net/url" + "os" + "github.com/gorilla/sessions" "github.com/markbates/goth" "github.com/markbates/goth/gothic" "github.com/markbates/goth/providers/openidConnect" - "log" - "os" ) func NewAuth(sessionStore *sessions.Store) { @@ -24,6 +28,83 @@ func NewAuth(sessionStore *sessions.Store) { if sessionStore != nil { gothic.Store = *sessionStore } else { - log.Println("No auth session store set. Falling back to default gothic setting.") - } + log.Println("No auth session store set. Falling back to default gothic setting.") + } +} + +func CompleteAuthFlow(res http.ResponseWriter, req *http.Request) (goth.User, error) { + providerName, err := gothic.GetProviderName(req) + if err != nil { + log.Println("Error getting provider name: ", err) + return goth.User{}, err + } + + provider, err := goth.GetProvider(providerName) + if err != nil { + return goth.User{}, err + } + + value, err := gothic.GetFromSession(providerName, req) + if err != nil { + return goth.User{}, err + } + defer Logout(res, req) + sess, err := provider.UnmarshalSession(value) + if err != nil { + return goth.User{}, err + } + + err = validateState(req, sess) + if err != nil { + return goth.User{}, err + } + + user, err := provider.FetchUser(sess) + if err == nil { + + return user, err + } + + params := req.URL.Query() + if params.Encode() == "" && req.Method == "POST" { + req.ParseForm() + params = req.Form + } + + _, err = sess.Authorize(provider, params) + if err != nil { + return goth.User{}, err + } + + err = gothic.StoreInSession(providerName, sess.Marshal(), req, res) + if err != nil { + return goth.User{}, err + } + + gu, err := provider.FetchUser(sess) + return gu, err +} + +func Logout(res http.ResponseWriter, req *http.Request) error { + return gothic.Logout(res, req) +} + +func validateState(req *http.Request, sess goth.Session) error { + rawAuthURL, err := sess.GetAuthURL() + if err != nil { + return err + } + + authURL, err := url.Parse(rawAuthURL) + if err != nil { + return err + } + + reqState := gothic.GetState(req) + + originalState := authURL.Query().Get("state") + if originalState != "" && (originalState != reqState) { + return errors.New("state token mismatch") + } + return nil } diff --git a/internal/server/routes.go b/internal/server/routes.go index 8bbc3d1..8b7c8b2 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -1,6 +1,7 @@ package server import ( + "gothtest/internal/auth" "encoding/json" "log" "fmt" @@ -54,7 +55,7 @@ func (s *Server) getAuthCallbackFunc(w http.ResponseWriter, r *http.Request) { provider := chi.URLParam(r, "provider") r = r.WithContext(context.WithValue(context.Background(), "provider", provider)) - user, err := gothic.CompleteUserAuth(w, r) + user, err := auth.CompleteAuthFlow(w, r) if err != nil { fmt.Fprintln(w, err) return