+ {{ .locale.Tr "auth.oauth.signing-in-with" $.c.OIDCProviderName }} +
+{{ .locale.Tr "auth.oauth.existing-account" }}
++ {{ .locale.Tr "auth.oauth.already-have-account" $.c.OIDCProviderName }} +
+ +diff --git a/internal/auth/oauth/gitea.go b/internal/auth/oauth/gitea.go index 88db9bd..06eed11 100644 --- a/internal/auth/oauth/gitea.go +++ b/internal/auth/oauth/gitea.go @@ -110,6 +110,10 @@ func (p *GiteaCallbackProvider) UpdateUserDB(user *db.User) { user.AvatarURL = field.(string) } +func (p *GiteaCallbackProvider) IsAdmin() bool { + return false +} + func NewGiteaCallbackProvider(user *goth.User) CallbackProvider { return &GiteaCallbackProvider{ User: user, diff --git a/internal/auth/oauth/github.go b/internal/auth/oauth/github.go index 02cd018..81fece4 100644 --- a/internal/auth/oauth/github.go +++ b/internal/auth/oauth/github.go @@ -77,6 +77,10 @@ func (p *GitHubCallbackProvider) UpdateUserDB(user *db.User) { user.AvatarURL = "https://avatars.githubusercontent.com/u/" + p.User.UserID + "?v=4" } +func (p *GitHubCallbackProvider) IsAdmin() bool { + return false +} + func NewGitHubCallbackProvider(user *goth.User) CallbackProvider { return &GitHubCallbackProvider{ User: user, diff --git a/internal/auth/oauth/gitlab.go b/internal/auth/oauth/gitlab.go index 91195f5..42f1135 100644 --- a/internal/auth/oauth/gitlab.go +++ b/internal/auth/oauth/gitlab.go @@ -111,6 +111,10 @@ func (p *GitLabCallbackProvider) UpdateUserDB(user *db.User) { user.AvatarURL = field.(string) } +func (p *GitLabCallbackProvider) IsAdmin() bool { + return false +} + func NewGitLabCallbackProvider(user *goth.User) CallbackProvider { return &GitLabCallbackProvider{ User: user, diff --git a/internal/auth/oauth/openid.go b/internal/auth/oauth/openid.go index ffa17f2..2533b2a 100644 --- a/internal/auth/oauth/openid.go +++ b/internal/auth/oauth/openid.go @@ -3,6 +3,8 @@ package oauth import ( gocontext "context" "errors" + "slices" + "github.com/markbates/goth" "github.com/markbates/goth/gothic" "github.com/markbates/goth/providers/openidConnect" @@ -79,6 +81,31 @@ func (p *OIDCCallbackProvider) UpdateUserDB(user *db.User) { user.AvatarURL = p.User.AvatarURL } +func (p *OIDCCallbackProvider) IsAdmin() bool { + if config.C.OIDCAdminGroup == "" { + return false + } + + groupClaimName := config.C.OIDCGroupClaimName + if groupClaimName == "" { + return false + } + + groups, ok := p.User.RawData[groupClaimName].([]interface{}) + if !ok { + return false + } + + var groupNames []string + for _, group := range groups { + if groupName, ok := group.(string); ok { + groupNames = append(groupNames, groupName) + } + } + + return slices.Contains(groupNames, config.C.OIDCAdminGroup) +} + func NewOIDCCallbackProvider(user *goth.User) CallbackProvider { return &OIDCCallbackProvider{ User: user, diff --git a/internal/auth/oauth/provider.go b/internal/auth/oauth/provider.go index 951e9f4..9b1d511 100644 --- a/internal/auth/oauth/provider.go +++ b/internal/auth/oauth/provider.go @@ -2,15 +2,16 @@ package oauth import ( "fmt" + "io" + "net/http" + "net/url" + "strings" + "github.com/markbates/goth" "github.com/markbates/goth/gothic" "github.com/rs/zerolog/log" "github.com/thomiceli/opengist/internal/db" "github.com/thomiceli/opengist/internal/web/context" - "io" - "net/http" - "net/url" - "strings" ) const ( @@ -32,6 +33,7 @@ type CallbackProvider interface { GetProviderUserID(user *db.User) bool GetProviderUserSSHKeys() ([]string, error) UpdateUserDB(user *db.User) + IsAdmin() bool } func DefineProvider(provider string, url string) (Provider, error) { @@ -69,6 +71,29 @@ func CompleteUserAuth(ctx *context.Context) (CallbackProvider, error) { return nil, fmt.Errorf("unsupported provider %s", user.Provider) } +func NewCallbackProviderFromSession(provider string, userID string, nickname string, email string, avatarURL string) (CallbackProvider, error) { + user := &goth.User{ + Provider: provider, + UserID: userID, + NickName: nickname, + Email: email, + AvatarURL: avatarURL, + } + + switch provider { + case GitHubProviderString: + return NewGitHubCallbackProvider(user), nil + case GitLabProviderString: + return NewGitLabCallbackProvider(user), nil + case GiteaProviderString: + return NewGiteaCallbackProvider(user), nil + case OpenIDConnectString: + return NewOIDCCallbackProvider(user), nil + } + + return nil, fmt.Errorf("unsupported provider %s", provider) +} + func urlJoin(base string, elem ...string) string { joined, err := url.JoinPath(base, elem...) if err != nil { diff --git a/internal/db/user.go b/internal/db/user.go index 9bd53c3..e11bf87 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -258,6 +258,11 @@ type UserDTO struct { Password string `form:"password" validate:"required"` } +type OAuthRegisterDTO struct { + Username string `form:"username" validate:"required,max=24,alphanumdash,notreserved"` + Email string `form:"email" validate:"omitempty,email"` +} + func (dto *UserDTO) ToUser() *User { return &User{ Username: dto.Username, diff --git a/internal/i18n/locales/en-US.yml b/internal/i18n/locales/en-US.yml index d641872..9c88f72 100644 --- a/internal/i18n/locales/en-US.yml +++ b/internal/i18n/locales/en-US.yml @@ -200,6 +200,13 @@ auth.password: Password auth.register-instead: Register instead auth.login-instead: Login instead auth.oauth: Continue with %s account +auth.oauth.no-provider: OAuth provider not found +auth.oauth.complete-registration: Complete your registration +auth.oauth.complete-registration-button: Create account +auth.oauth.signing-in-with: Signing in with %s +auth.oauth.cancel: Cancel +auth.oauth.existing-account: Existing account? +auth.oauth.already-have-account: If you already have an Opengist account, login first and link your %s account from your settings. auth.mfa: Multi-factor authentication auth.mfa.passkey: Passkey auth.mfa.passkeys: Passkeys @@ -241,7 +248,7 @@ error.signup-disabled: Signing up is disabled error.signup-disabled-form: Signing up via registration form is disabled error.login-disabled-form: Logging in via login form is disabled error.complete-oauth-login: "Cannot complete user auth: %s" -error.oauth-unsupported: Unsupported provider +error.oauth-unsupported: Unsupported OAuth2 provider error.cannot-bind-data: Cannot bind data error.invalid-number: Invalid number error.invalid-character-unescaped: Invalid character unescaped @@ -343,6 +350,8 @@ flash.auth.user-sshkeys-not-created: Could not create ssh key flash.auth.must-be-logged-in: You must be logged in to access gists flash.auth.passkey-registred: Passkey %s registered flash.auth.passkey-deleted: Passkey deleted +flash.auth.oauth-session-expired: OAuth2 session expired, please try again +flash.auth.oauth-already-linked: This %s account is already linked to another user flash.gist.visibility-changed: Gist visibility has been changed flash.gist.deleted: Gist has been deleted diff --git a/internal/validator/validator.go b/internal/validator/validator.go index 5b2649a..91721c6 100644 --- a/internal/validator/validator.go +++ b/internal/validator/validator.go @@ -59,7 +59,7 @@ func validateReservedKeywords(fl validator.FieldLevel) bool { name := fl.Field().String() restrictedNames := map[string]struct{}{} - for _, restrictedName := range []string{"assets", "register", "login", "logout", "settings", "admin-panel", "all", "search", "init", "healthcheck", "preview", "metrics", "mfa", "webauthn"} { + for _, restrictedName := range []string{"assets", "register", "login", "logout", "settings", "admin-panel", "all", "search", "init", "healthcheck", "preview", "metrics", "mfa", "webauthn", "oauth"} { restrictedNames[restrictedName] = struct{}{} } diff --git a/internal/web/handlers/auth/oauth.go b/internal/web/handlers/auth/oauth.go index fd1831f..116b6fb 100644 --- a/internal/web/handlers/auth/oauth.go +++ b/internal/web/handlers/auth/oauth.go @@ -4,16 +4,15 @@ import ( "crypto/md5" "errors" "fmt" - "slices" "strings" "github.com/rs/zerolog/log" "github.com/thomiceli/opengist/internal/auth/oauth" "github.com/thomiceli/opengist/internal/config" "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/i18n" + "github.com/thomiceli/opengist/internal/validator" "github.com/thomiceli/opengist/internal/web/context" - "golang.org/x/text/cases" - "golang.org/x/text/language" "gorm.io/gorm" ) @@ -48,7 +47,8 @@ func Oauth(ctx *context.Context) error { provider, err := oauth.DefineProvider(providerStr, opengistUrl) if err != nil { - return ctx.ErrorRes(400, ctx.Tr("error.oauth-unsupported"), nil) + ctx.AddFlash(ctx.Tr("error.oauth-unsupported"), "error") + return ctx.Redirect(302, "/login") } if err = provider.RegisterProvider(); err != nil { @@ -62,28 +62,37 @@ func Oauth(ctx *context.Context) error { func OauthCallback(ctx *context.Context) error { provider, err := oauth.CompleteUserAuth(ctx) if err != nil { - return ctx.ErrorRes(400, ctx.Tr("error.complete-oauth-login", err.Error()), err) + ctx.AddFlash(ctx.Tr("auth.oauth.no-provider"), "error") + return ctx.Redirect(302, "/login") } currUser := ctx.User + user := provider.GetProviderUser() + // if user is logged in, link account to user and update its avatar URL if currUser != nil { + // check if this OAuth account is already linked to another user + if existingUser, err := db.GetUserByProvider(user.UserID, provider.GetProvider()); err == nil && existingUser != nil { + ctx.AddFlash(ctx.Tr("flash.auth.oauth-already-linked", config.C.OIDCProviderName), "error") + return ctx.RedirectTo("/settings") + } + provider.UpdateUserDB(currUser) if err = currUser.Update(); err != nil { - return ctx.ErrorRes(500, "Cannot update user "+cases.Title(language.English).String(provider.GetProvider())+" id", err) + return ctx.ErrorRes(500, "Cannot update user "+config.C.OIDCProviderName+" id", err) } - ctx.AddFlash(ctx.Tr("flash.auth.account-linked-oauth", cases.Title(language.English).String(provider.GetProvider())), "success") + ctx.AddFlash(ctx.Tr("flash.auth.account-linked-oauth", config.C.OIDCProviderName), "success") return ctx.RedirectTo("/settings") } - user := provider.GetProviderUser() userDB, err := db.GetUserByProvider(user.UserID, provider.GetProvider()) - // if user is not in database, create it + // if user is not in database, redirect to OAuth registration page if err != nil { if ctx.GetData("DisableSignup") == true { - return ctx.ErrorRes(403, ctx.Tr("error.signup-disabled"), nil) + ctx.AddFlash(ctx.Tr("error.signup-disabled"), "error") + return ctx.Redirect(302, "/login") } if !errors.Is(err, gorm.ErrRecordNotFound) { @@ -94,74 +103,25 @@ func OauthCallback(ctx *context.Context) error { user.NickName = strings.Split(user.Email, "@")[0] } - userDB = &db.User{ - Username: user.NickName, - Email: user.Email, - MD5Hash: fmt.Sprintf("%x", md5.Sum([]byte(strings.ToLower(strings.TrimSpace(user.Email))))), - } + sess := ctx.GetSession() + sess.Values["oauthProvider"] = provider.GetProvider() + sess.Values["oauthUserID"] = user.UserID + sess.Values["oauthNickname"] = user.NickName + sess.Values["oauthEmail"] = user.Email + sess.Values["oauthAvatarURL"] = user.AvatarURL + sess.Values["oauthIsAdmin"] = provider.IsAdmin() - // set provider id and avatar URL - provider.UpdateUserDB(userDB) + sess.Options.MaxAge = 10 * 60 // 10 minutes + ctx.SaveSession(sess) - if err = userDB.Create(); err != nil { - if db.IsUniqueConstraintViolation(err) { - ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error") - return ctx.RedirectTo("/login") - } - - return ctx.ErrorRes(500, "Cannot create user", err) - } - - // if oidc admin group is not configured set first user as admin - if config.C.OIDCAdminGroup == "" && userDB.ID == 1 { - if err = userDB.SetAdmin(); err != nil { - return ctx.ErrorRes(500, "Cannot set user admin", err) - } - } - - keys, err := provider.GetProviderUserSSHKeys() - if err != nil { - ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-retrievable"), "error") - log.Error().Err(err).Msg("Could not get user keys") - } else { - for _, key := range keys { - sshKey := db.SSHKey{ - Title: "Added from " + user.Provider, - Content: key, - User: *userDB, - } - - if err = sshKey.Create(); err != nil { - ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-created"), "error") - log.Error().Err(err).Msg("Could not create ssh key") - } - } - } + return ctx.RedirectTo("/oauth/register") } - // update is admin status from oidc group - if config.C.OIDCAdminGroup != "" { - groupClaimName := config.C.OIDCGroupClaimName - if groupClaimName == "" { - log.Error().Msg("No OIDC group claim name configured") - } else if groups, ok := user.RawData[groupClaimName].([]interface{}); ok { - var groupNames []string - for _, group := range groups { - if groupName, ok := group.(string); ok { - groupNames = append(groupNames, groupName) - } - } - isOIDCAdmin := slices.Contains(groupNames, config.C.OIDCAdminGroup) - log.Debug().Bool("isOIDCAdmin", isOIDCAdmin).Str("user", user.Name).Msg("User is in admin group") - - if userDB.IsAdmin != isOIDCAdmin { - userDB.IsAdmin = isOIDCAdmin - if err = userDB.Update(); err != nil { - return ctx.ErrorRes(500, "Cannot set user admin", err) - } - } - } else { - log.Error().Msg("No groups found in user data") + // promote user to admin from oidc group + if !userDB.IsAdmin && provider.IsAdmin() { + userDB.IsAdmin = true + if err = userDB.Update(); err != nil { + return ctx.ErrorRes(500, "Cannot set user admin", err) } } @@ -173,6 +133,150 @@ func OauthCallback(ctx *context.Context) error { return ctx.RedirectTo("/") } +func OauthRegister(ctx *context.Context) error { + if ctx.GetData("DisableSignup") == true { + ctx.AddFlash(ctx.Tr("error.signup-disabled"), "error") + return ctx.Redirect(302, "/login") + } + + sess := ctx.GetSession() + + ctx.SetData("title", ctx.TrH("auth.oauth.complete-registration")) + ctx.SetData("htmlTitle", ctx.TrH("auth.oauth.complete-registration")) + ctx.SetData("oauthProvider", config.C.OIDCProviderName) + ctx.SetData("oauthNickname", sess.Values["oauthNickname"]) + ctx.SetData("oauthEmail", sess.Values["oauthEmail"]) + ctx.SetData("oauthAvatarURL", sess.Values["oauthAvatarURL"]) + + return ctx.Html("oauth_register.html") +} + +func ProcessOauthRegister(ctx *context.Context) error { + if ctx.GetData("DisableSignup") == true { + ctx.AddFlash(ctx.Tr("error.signup-disabled"), "error") + return ctx.Redirect(302, "/login") + } + + sess := ctx.GetSession() + + providerStr := sess.Values["oauthProvider"].(string) + oauthUserID := sess.Values["oauthUserID"].(string) + + setOauthRegisterData := func(dto *db.OAuthRegisterDTO) { + ctx.SetData("title", ctx.TrH("auth.oauth.complete-registration")) + ctx.SetData("htmlTitle", ctx.TrH("auth.oauth.complete-registration")) + ctx.SetData("oauthProvider", config.C.OIDCProviderName) + if dto != nil { + ctx.SetData("oauthNickname", dto.Username) + ctx.SetData("oauthEmail", dto.Email) + } else { + ctx.SetData("oauthNickname", sess.Values["oauthNickname"]) + ctx.SetData("oauthEmail", sess.Values["oauthEmail"]) + } + ctx.SetData("oauthAvatarURL", sess.Values["oauthAvatarURL"]) + } + + // Bind and validate form data + dto := new(db.OAuthRegisterDTO) + if err := ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + + if err := ctx.Validate(dto); err != nil { + ctx.AddFlash(validator.ValidationMessages(&err, ctx.GetData("locale").(*i18n.Locale)), "error") + setOauthRegisterData(dto) + return ctx.Html("oauth_register.html") + } + + if exists, err := db.UserExists(dto.Username); err != nil || exists { + ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error") + setOauthRegisterData(dto) + return ctx.Html("oauth_register.html") + } + + // Check if OAuth account is already linked to another user (race condition protection) + if existingUser, err := db.GetUserByProvider(oauthUserID, providerStr); err == nil && existingUser != nil { + ctx.AddFlash(ctx.Tr("flash.auth.oauth-already-linked", config.C.OIDCProviderName), "error") + setOauthRegisterData(dto) + return ctx.Html("oauth_register.html") + } + + userDB := &db.User{ + Username: dto.Username, + Email: dto.Email, + } + if dto.Email != "" { + userDB.MD5Hash = fmt.Sprintf("%x", md5.Sum([]byte(strings.ToLower(strings.TrimSpace(dto.Email))))) + } + + nickname := "" + if n, ok := sess.Values["oauthNickname"].(string); ok { + nickname = n + } + avatarURL := "" + if av, ok := sess.Values["oauthAvatarURL"].(string); ok { + avatarURL = av + } + + callbackProvider, err := oauth.NewCallbackProviderFromSession(providerStr, oauthUserID, nickname, dto.Email, avatarURL) + if err != nil { + return ctx.ErrorRes(500, "Cannot create provider", err) + } + callbackProvider.UpdateUserDB(userDB) + + if err := userDB.Create(); err != nil { + if db.IsUniqueConstraintViolation(err) { + ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error") + setOauthRegisterData(dto) + return ctx.Html("oauth_register.html") + } + return ctx.ErrorRes(500, "Cannot create user", err) + } + + if config.C.OIDCAdminGroup == "" && userDB.ID == 1 { + if err := userDB.SetAdmin(); err != nil { + return ctx.ErrorRes(500, "Cannot set user admin", err) + } + } + + if isAdmin, ok := sess.Values["oauthIsAdmin"].(bool); ok && isAdmin { + userDB.IsAdmin = true + _ = userDB.Update() + } + + keys, err := callbackProvider.GetProviderUserSSHKeys() + if err != nil { + ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-retrievable"), "error") + log.Error().Err(err).Msg("Could not get user keys") + } else { + for _, key := range keys { + sshKey := db.SSHKey{ + Title: "Added from " + providerStr, + Content: key, + User: *userDB, + } + if err = sshKey.Create(); err != nil { + ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-created"), "error") + log.Error().Err(err).Msg("Could not create ssh key") + } + } + } + + delete(sess.Values, "oauthProvider") + delete(sess.Values, "oauthUserID") + delete(sess.Values, "oauthNickname") + delete(sess.Values, "oauthEmail") + delete(sess.Values, "oauthAvatarURL") + delete(sess.Values, "oauthIsAdmin") + + sess.Values["user"] = userDB.ID + sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year + ctx.SaveSession(sess) + ctx.DeleteCsrfCookie() + + return ctx.RedirectTo("/") +} + func OauthUnlink(ctx *context.Context) error { providerStr := ctx.Param("provider") provider, err := oauth.DefineProvider(ctx.Param("provider"), "") @@ -184,10 +288,10 @@ func OauthUnlink(ctx *context.Context) error { if provider.UserHasProvider(currUser) { if err := currUser.DeleteProviderID(providerStr); err != nil { - return ctx.ErrorRes(500, "Cannot unlink account from "+cases.Title(language.English).String(providerStr), err) + return ctx.ErrorRes(500, "Cannot unlink account from "+config.C.OIDCProviderName, err) } - ctx.AddFlash(ctx.Tr("flash.auth.account-unlinked-oauth", cases.Title(language.English).String(providerStr)), "success") + ctx.AddFlash(ctx.Tr("flash.auth.account-unlinked-oauth", config.C.OIDCProviderName), "success") return ctx.RedirectTo("/settings") } diff --git a/internal/web/server/middlewares.go b/internal/web/server/middlewares.go index 258bb1a..8f79b02 100644 --- a/internal/web/server/middlewares.go +++ b/internal/web/server/middlewares.go @@ -199,6 +199,17 @@ func inMFASession(next Handler) Handler { } } +func inOAuthRegisterSession(next Handler) Handler { + return func(ctx *context.Context) error { + sess := ctx.GetSession() + _, ok := sess.Values["oauthProvider"].(string) + if !ok { + return ctx.RedirectTo("/login") + } + return next(ctx) + } +} + func makeCheckRequireLogin(isSingleGistAccess bool) Middleware { return func(next Handler) Handler { return func(ctx *context.Context) error { diff --git a/internal/web/server/router.go b/internal/web/server/router.go index 99a44ce..635f183 100644 --- a/internal/web/server/router.go +++ b/internal/web/server/router.go @@ -38,6 +38,8 @@ func (s *Server) registerRoutes() { r.GET("/login", auth.Login) r.POST("/login", auth.ProcessLogin) r.GET("/logout", auth.Logout) + r.GET("/oauth/register", auth.OauthRegister, inOAuthRegisterSession) + r.POST("/oauth/register", auth.ProcessOauthRegister, inOAuthRegisterSession) r.GET("/oauth/:provider", auth.Oauth) r.GET("/oauth/:provider/callback", auth.OauthCallback) r.GET("/oauth/:provider/unlink", auth.OauthUnlink, logged) diff --git a/templates/pages/oauth_register.html b/templates/pages/oauth_register.html new file mode 100644 index 0000000..d7d688d --- /dev/null +++ b/templates/pages/oauth_register.html @@ -0,0 +1,85 @@ +{{ template "header" .}} +
+ {{ .locale.Tr "auth.oauth.signing-in-with" $.c.OIDCProviderName }} +
+{{ .locale.Tr "auth.oauth.existing-account" }}
++ {{ .locale.Tr "auth.oauth.already-have-account" $.c.OIDCProviderName }} +
+ +