Add gothic example code to internal/auth

This commit is contained in:
Lucas Schumacher 2024-07-20 06:56:45 -04:00
parent 6279037386
commit b4ba7f9736
2 changed files with 87 additions and 5 deletions

View File

@ -1,12 +1,16 @@
package auth package auth
import ( import (
"errors"
"log"
"net/http"
"net/url"
"os"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/markbates/goth" "github.com/markbates/goth"
"github.com/markbates/goth/gothic" "github.com/markbates/goth/gothic"
"github.com/markbates/goth/providers/openidConnect" "github.com/markbates/goth/providers/openidConnect"
"log"
"os"
) )
func NewAuth(sessionStore *sessions.Store) { func NewAuth(sessionStore *sessions.Store) {
@ -27,3 +31,80 @@ func NewAuth(sessionStore *sessions.Store) {
log.Println("No auth session store set. Falling back to default gothic setting.") log.Println("No auth session store set. Falling back to default gothic setting.")
} }
} }
func CompleteAuthFlow(res http.ResponseWriter, req *http.Request) (goth.User, error) {
providerName, err := gothic.GetProviderName(req)
if err != nil {
log.Println("Error getting provider name: ", err)
return goth.User{}, err
}
provider, err := goth.GetProvider(providerName)
if err != nil {
return goth.User{}, err
}
value, err := gothic.GetFromSession(providerName, req)
if err != nil {
return goth.User{}, err
}
defer Logout(res, req)
sess, err := provider.UnmarshalSession(value)
if err != nil {
return goth.User{}, err
}
err = validateState(req, sess)
if err != nil {
return goth.User{}, err
}
user, err := provider.FetchUser(sess)
if err == nil {
return user, err
}
params := req.URL.Query()
if params.Encode() == "" && req.Method == "POST" {
req.ParseForm()
params = req.Form
}
_, err = sess.Authorize(provider, params)
if err != nil {
return goth.User{}, err
}
err = gothic.StoreInSession(providerName, sess.Marshal(), req, res)
if err != nil {
return goth.User{}, err
}
gu, err := provider.FetchUser(sess)
return gu, err
}
func Logout(res http.ResponseWriter, req *http.Request) error {
return gothic.Logout(res, req)
}
func validateState(req *http.Request, sess goth.Session) error {
rawAuthURL, err := sess.GetAuthURL()
if err != nil {
return err
}
authURL, err := url.Parse(rawAuthURL)
if err != nil {
return err
}
reqState := gothic.GetState(req)
originalState := authURL.Query().Get("state")
if originalState != "" && (originalState != reqState) {
return errors.New("state token mismatch")
}
return nil
}

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"gothtest/internal/auth"
"encoding/json" "encoding/json"
"log" "log"
"fmt" "fmt"
@ -54,7 +55,7 @@ func (s *Server) getAuthCallbackFunc(w http.ResponseWriter, r *http.Request) {
provider := chi.URLParam(r, "provider") provider := chi.URLParam(r, "provider")
r = r.WithContext(context.WithValue(context.Background(), "provider", provider)) r = r.WithContext(context.WithValue(context.Background(), "provider", provider))
user, err := gothic.CompleteUserAuth(w, r) user, err := auth.CompleteAuthFlow(w, r)
if err != nil { if err != nil {
fmt.Fprintln(w, err) fmt.Fprintln(w, err)
return return