111 lines
2.4 KiB
Go

package auth
import (
"errors"
"log"
"net/http"
"net/url"
"os"
"github.com/gorilla/sessions"
"github.com/markbates/goth"
"github.com/markbates/goth/gothic"
"github.com/markbates/goth/providers/openidConnect"
)
// NewAuth registers an OpenID Connect provider with goth, using settings
// read from the environment, and optionally installs sessionStore as the
// store gothic uses for its auth sessions.
//
// Environment variables:
//   - OIDC_ID: client ID passed to openidConnect.New.
//   - OIDC_SECRET: client secret passed to openidConnect.New.
//   - OIDC_DISC_URL: OpenID discovery URL passed to openidConnect.New.
//   - OIDC_REDIRECT_URL (optional): callback URL; defaults to
//     http://localhost:3003/auth/openid-connect/callback when unset or empty.
//
// If sessionStore is nil (or points to a nil Store), gothic keeps its default
// store. NewAuth terminates the process via log.Fatalf if the provider cannot
// be created.
func NewAuth(sessionStore *sessions.Store) {
	const defaultRedirectURL = "http://localhost:3003/auth/openid-connect/callback"

	clientID := os.Getenv("OIDC_ID")
	clientSecret := os.Getenv("OIDC_SECRET")
	discoveryURL := os.Getenv("OIDC_DISC_URL")
	redirectURL := os.Getenv("OIDC_REDIRECT_URL")
	if redirectURL == "" {
		redirectURL = defaultRedirectURL
	}

	// Named "provider" (not "openidConnect") so the package identifier is
	// not shadowed for the rest of the function.
	provider, err := openidConnect.New(clientID, clientSecret, redirectURL, discoveryURL)
	if err != nil || provider == nil {
		// Include the underlying error; without it misconfiguration is
		// very hard to diagnose.
		log.Fatalf("Error setting up oidc: %v", err)
	}
	goth.UseProviders(provider)

	// sessions.Store is an interface, so a non-nil pointer can still hold a
	// nil Store; installing that would leave gothic with no usable store.
	if sessionStore == nil || *sessionStore == nil {
		log.Println("No auth session store set. Falling back to default gothic setting.")
		return
	}
	gothic.Store = *sessionStore
}
// CompleteAuthFlow finishes a login that was started through gothic and
// returns the authenticated user.
//
// It resolves the provider name from req and restores the provider session
// that gothic stored when the flow began. It then checks the callback's state
// parameter against that session (see validateState).
//
// If the restored session can already fetch the user, that user is returned.
// Otherwise the callback parameters are used to authorize the session. They
// come from the query string, or from the form body for a POST with an empty
// query. The updated session is stored and the user is fetched from it.
//
// Once the stored session value has been read, Logout is deferred, so the
// gothic session is invalidated when this function returns on every later
// path, success or failure.
func CompleteAuthFlow(res http.ResponseWriter, req *http.Request) (goth.User, error) {
	providerName, err := gothic.GetProviderName(req)
	if err != nil {
		log.Println("Error getting provider name:", err)
		return goth.User{}, err
	}
	provider, err := goth.GetProvider(providerName)
	if err != nil {
		return goth.User{}, err
	}
	value, err := gothic.GetFromSession(providerName, req)
	if err != nil {
		return goth.User{}, err
	}
	// Invalidate the gothic session on every exit past this point. Logout's
	// error is intentionally ignored: the result reported to the caller is
	// the outcome of the auth flow, not of the cleanup.
	defer Logout(res, req)

	sess, err := provider.UnmarshalSession(value)
	if err != nil {
		return goth.User{}, err
	}
	if err := validateState(req, sess); err != nil {
		return goth.User{}, err
	}

	// The restored session may already be authorized; if so, no code
	// exchange is needed.
	user, err := provider.FetchUser(sess)
	if err == nil {
		return user, nil
	}

	params := req.URL.Query()
	if params.Encode() == "" && req.Method == http.MethodPost {
		// A POST callback may carry its parameters in the body rather than
		// the query string.
		if err := req.ParseForm(); err != nil {
			return goth.User{}, err
		}
		params = req.Form
	}

	if _, err := sess.Authorize(provider, params); err != nil {
		return goth.User{}, err
	}
	if err := gothic.StoreInSession(providerName, sess.Marshal(), req, res); err != nil {
		return goth.User{}, err
	}
	return provider.FetchUser(sess)
}
// Logout invalidates the gothic auth session associated with req by
// delegating to gothic.Logout. It returns whatever error gothic.Logout
// reports, or nil on success.
func Logout(res http.ResponseWriter, req *http.Request) error {
	if err := gothic.Logout(res, req); err != nil {
		return err
	}
	return nil
}
// validateState compares the "state" query parameter embedded in the
// session's original auth URL with the state value gothic extracts from the
// incoming callback request.
//
// It returns an error if the auth URL cannot be obtained or parsed, or
// "state token mismatch" when the two values differ. If the original auth
// URL carried no state parameter, the check is skipped and nil is returned.
func validateState(req *http.Request, sess goth.Session) error {
	authURLString, err := sess.GetAuthURL()
	if err != nil {
		return err
	}
	parsedAuthURL, err := url.Parse(authURLString)
	if err != nil {
		return err
	}

	// Read the request's state before inspecting the auth URL, so this call
	// and any side effects it has happen on every path, as before.
	callbackState := gothic.GetState(req)

	if expected := parsedAuthURL.Query().Get("state"); expected != "" && expected != callbackState {
		return errors.New("state token mismatch")
	}
	return nil
}