111 lines
2.4 KiB
Go
111 lines
2.4 KiB
Go
package auth
|
|
|
|
import (
|
|
"errors"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
|
|
"github.com/gorilla/sessions"
|
|
"github.com/markbates/goth"
|
|
"github.com/markbates/goth/gothic"
|
|
"github.com/markbates/goth/providers/openidConnect"
|
|
)
|
|
|
|
func NewAuth(sessionStore *sessions.Store) {
|
|
oidcId := os.Getenv("OIDC_ID")
|
|
oidcSec := os.Getenv("OIDC_SECRET")
|
|
oidcDiscUrl := os.Getenv("OIDC_DISC_URL")
|
|
oidcRedirectUrl := "http://localhost:3003/auth/openid-connect/callback"
|
|
|
|
openidConnect, err := openidConnect.New(oidcId, oidcSec, oidcRedirectUrl, oidcDiscUrl)
|
|
if openidConnect == nil || err != nil {
|
|
log.Fatal("Error setting up oidc")
|
|
}
|
|
goth.UseProviders(openidConnect)
|
|
|
|
if sessionStore != nil {
|
|
gothic.Store = *sessionStore
|
|
} else {
|
|
log.Println("No auth session store set. Falling back to default gothic setting.")
|
|
}
|
|
}
|
|
|
|
func CompleteAuthFlow(res http.ResponseWriter, req *http.Request) (goth.User, error) {
|
|
providerName, err := gothic.GetProviderName(req)
|
|
if err != nil {
|
|
log.Println("Error getting provider name: ", err)
|
|
return goth.User{}, err
|
|
}
|
|
|
|
provider, err := goth.GetProvider(providerName)
|
|
if err != nil {
|
|
return goth.User{}, err
|
|
}
|
|
|
|
value, err := gothic.GetFromSession(providerName, req)
|
|
if err != nil {
|
|
return goth.User{}, err
|
|
}
|
|
defer Logout(res, req)
|
|
sess, err := provider.UnmarshalSession(value)
|
|
if err != nil {
|
|
return goth.User{}, err
|
|
}
|
|
|
|
err = validateState(req, sess)
|
|
if err != nil {
|
|
return goth.User{}, err
|
|
}
|
|
|
|
user, err := provider.FetchUser(sess)
|
|
if err == nil {
|
|
|
|
return user, err
|
|
}
|
|
|
|
params := req.URL.Query()
|
|
if params.Encode() == "" && req.Method == "POST" {
|
|
req.ParseForm()
|
|
params = req.Form
|
|
}
|
|
|
|
_, err = sess.Authorize(provider, params)
|
|
if err != nil {
|
|
return goth.User{}, err
|
|
}
|
|
|
|
err = gothic.StoreInSession(providerName, sess.Marshal(), req, res)
|
|
if err != nil {
|
|
return goth.User{}, err
|
|
}
|
|
|
|
gu, err := provider.FetchUser(sess)
|
|
return gu, err
|
|
}
|
|
|
|
func Logout(res http.ResponseWriter, req *http.Request) error {
|
|
return gothic.Logout(res, req)
|
|
}
|
|
|
|
func validateState(req *http.Request, sess goth.Session) error {
|
|
rawAuthURL, err := sess.GetAuthURL()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
authURL, err := url.Parse(rawAuthURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
reqState := gothic.GetState(req)
|
|
|
|
originalState := authURL.Query().Get("state")
|
|
if originalState != "" && (originalState != reqState) {
|
|
return errors.New("state token mismatch")
|
|
}
|
|
return nil
|
|
}
|