package auth import ( "errors" "log" "net/http" "net/url" "os" "github.com/gorilla/sessions" "github.com/markbates/goth" "github.com/markbates/goth/gothic" "github.com/markbates/goth/providers/openidConnect" ) func NewAuth(sessionStore *sessions.Store) { oidcId := os.Getenv("OIDC_ID") oidcSec := os.Getenv("OIDC_SECRET") oidcDiscUrl := os.Getenv("OIDC_DISC_URL") oidcRedirectUrl := "http://localhost:3003/auth/openid-connect/callback" openidConnect, err := openidConnect.New(oidcId, oidcSec, oidcRedirectUrl, oidcDiscUrl) if openidConnect == nil || err != nil { log.Fatal("Error setting up oidc") } goth.UseProviders(openidConnect) if sessionStore != nil { gothic.Store = *sessionStore } else { log.Println("No auth session store set. Falling back to default gothic setting.") } } func CompleteAuthFlow(res http.ResponseWriter, req *http.Request) (goth.User, error) { providerName, err := gothic.GetProviderName(req) if err != nil { log.Println("Error getting provider name: ", err) return goth.User{}, err } provider, err := goth.GetProvider(providerName) if err != nil { return goth.User{}, err } value, err := gothic.GetFromSession(providerName, req) if err != nil { return goth.User{}, err } defer Logout(res, req) sess, err := provider.UnmarshalSession(value) if err != nil { return goth.User{}, err } err = validateState(req, sess) if err != nil { return goth.User{}, err } user, err := provider.FetchUser(sess) if err == nil { return user, err } params := req.URL.Query() if params.Encode() == "" && req.Method == "POST" { req.ParseForm() params = req.Form } _, err = sess.Authorize(provider, params) if err != nil { return goth.User{}, err } err = gothic.StoreInSession(providerName, sess.Marshal(), req, res) if err != nil { return goth.User{}, err } gu, err := provider.FetchUser(sess) return gu, err } func Logout(res http.ResponseWriter, req *http.Request) error { return gothic.Logout(res, req) } func validateState(req *http.Request, sess goth.Session) error { rawAuthURL, err := sess.GetAuthURL() if err != nil { return err } authURL, err := url.Parse(rawAuthURL) if err != nil { return err } reqState := gothic.GetState(req) originalState := authURL.Query().Get("state") if originalState != "" && (originalState != reqState) { return errors.New("state token mismatch") } return nil }