Add gothic example code to internal/auth
This commit is contained in:
parent
6279037386
commit
b4ba7f9736
@ -1,12 +1,16 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
"github.com/markbates/goth"
|
"github.com/markbates/goth"
|
||||||
"github.com/markbates/goth/gothic"
|
"github.com/markbates/goth/gothic"
|
||||||
"github.com/markbates/goth/providers/openidConnect"
|
"github.com/markbates/goth/providers/openidConnect"
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewAuth(sessionStore *sessions.Store) {
|
func NewAuth(sessionStore *sessions.Store) {
|
||||||
@ -24,6 +28,83 @@ func NewAuth(sessionStore *sessions.Store) {
|
|||||||
if sessionStore != nil {
|
if sessionStore != nil {
|
||||||
gothic.Store = *sessionStore
|
gothic.Store = *sessionStore
|
||||||
} else {
|
} else {
|
||||||
log.Println("No auth session store set. Falling back to default gothic setting.")
|
log.Println("No auth session store set. Falling back to default gothic setting.")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CompleteAuthFlow(res http.ResponseWriter, req *http.Request) (goth.User, error) {
|
||||||
|
providerName, err := gothic.GetProviderName(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Error getting provider name: ", err)
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := goth.GetProvider(providerName)
|
||||||
|
if err != nil {
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := gothic.GetFromSession(providerName, req)
|
||||||
|
if err != nil {
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
defer Logout(res, req)
|
||||||
|
sess, err := provider.UnmarshalSession(value)
|
||||||
|
if err != nil {
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = validateState(req, sess)
|
||||||
|
if err != nil {
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := provider.FetchUser(sess)
|
||||||
|
if err == nil {
|
||||||
|
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
params := req.URL.Query()
|
||||||
|
if params.Encode() == "" && req.Method == "POST" {
|
||||||
|
req.ParseForm()
|
||||||
|
params = req.Form
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = sess.Authorize(provider, params)
|
||||||
|
if err != nil {
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = gothic.StoreInSession(providerName, sess.Marshal(), req, res)
|
||||||
|
if err != nil {
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
gu, err := provider.FetchUser(sess)
|
||||||
|
return gu, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func Logout(res http.ResponseWriter, req *http.Request) error {
|
||||||
|
return gothic.Logout(res, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateState(req *http.Request, sess goth.Session) error {
|
||||||
|
rawAuthURL, err := sess.GetAuthURL()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
authURL, err := url.Parse(rawAuthURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqState := gothic.GetState(req)
|
||||||
|
|
||||||
|
originalState := authURL.Query().Get("state")
|
||||||
|
if originalState != "" && (originalState != reqState) {
|
||||||
|
return errors.New("state token mismatch")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"gothtest/internal/auth"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
"log"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -54,7 +55,7 @@ func (s *Server) getAuthCallbackFunc(w http.ResponseWriter, r *http.Request) {
|
|||||||
provider := chi.URLParam(r, "provider")
|
provider := chi.URLParam(r, "provider")
|
||||||
r = r.WithContext(context.WithValue(context.Background(), "provider", provider))
|
r = r.WithContext(context.WithValue(context.Background(), "provider", provider))
|
||||||
|
|
||||||
user, err := gothic.CompleteUserAuth(w, r)
|
user, err := auth.CompleteAuthFlow(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintln(w, err)
|
fmt.Fprintln(w, err)
|
||||||
return
|
return
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user