diff --git a/authentication/google/auth.go b/authentication/google/auth.go index 20d5912..08d8ddb 100644 --- a/authentication/google/auth.go +++ b/authentication/google/auth.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "ersteller-lib" + "ersteller-lib/authentication" "github.com/gorilla/sessions" "github.com/labstack/echo/v4" "golang.org/x/oauth2" @@ -28,22 +29,17 @@ type AuthEnv struct { SessionName string } -type UserRepository interface { - GetUserId(email string) (int, error) - Create(email string) (int, error) -} - type Auth struct { Config oauth2.Config repo *GoogleAuthRepository isLocal bool - userRepo UserRepository + userRepo authentication.UserRepository sessionStore *sessions.CookieStore GoogleLoginRoute ersteller_lib.Route environment AuthEnv } -func NewAuth(env AuthEnv, repo *GoogleAuthRepository, userRepo UserRepository, sessionStore *sessions.CookieStore) *Auth { +func NewAuth(env AuthEnv, repo *GoogleAuthRepository, userRepo authentication.UserRepository, sessionStore *sessions.CookieStore) *Auth { config := oauth2.Config{ ClientID: env.GoogleClientId, ClientSecret: env.GoogleClientSecret, diff --git a/authentication/keycloak.go b/authentication/keycloak.go new file mode 100644 index 0000000..91cd282 --- /dev/null +++ b/authentication/keycloak.go @@ -0,0 +1,136 @@ +package authentication + +import ( + "context" + "ersteller-lib" + "github.com/gorilla/sessions" + "github.com/labstack/echo/v4" + "github.com/markbates/goth" + "github.com/markbates/goth/gothic" + "github.com/markbates/goth/providers/openidConnect" +) + +const OpenIdConnectPath = "/auth/openid-connect" + +type KeycloakEnv struct { + SessionSecret string + BaseUrl string + SessionName string + Keycloak struct { + CLientId string + ClientSecret string + DiscoveryUrl string + } + EmailSessionKey string + UserIdSessionKey string +} + +// https://keycloak-dev.deploy.ersteller.gorlug.de/realms/master/.well-known/openid-configuration +// https://go.dev/play/p/-RtLSPL4Wsj +func RunKeycloakAuth(e *echo.Echo, environment KeycloakEnv, cookieStore *sessions.CookieStore, userRepo UserRepository) { + sessionStore := sessions.NewFilesystemStore("store", []byte(environment.SessionSecret)) + sessionStore.MaxLength(8192) + gothic.Store = sessionStore + // OpenID Connect is based on OpenID Connect Auto Discovery URL (https://openid.net/specs/openid-connect-discovery-1_0-17.html) + // because the OpenID Connect provider initialize itself in the New(), it can return an error which should be handled or ignored + // ignore the error for now + openid, err := openidConnect.New(environment.Keycloak.CLientId, environment.Keycloak.ClientSecret, environment.BaseUrl+"/auth/openid-connect/callback", environment.Keycloak.DiscoveryUrl) + if err != nil { + ersteller_lib.Error("Error while initializing OpenID Connect provider: ", err) + panic(err) + } + if openid != nil { + goth.UseProviders(openid) + } + + e.GET("/auth/openid-connect/callback", func(c echo.Context) error { + + user, err := gothic.CompleteUserAuth(c.Response(), c.Request()) + if err != nil { + return err + } + userId, err := userRepo.GetUserId(user.Email) + if err != nil { + if userId == -1 { + userId, err = createUser(user, userRepo) + if err != nil { + ersteller_lib.LogError("Failed to create user: %v", err) + return err + } + ersteller_lib.LogDebug("Created user with id %d", userId) + } else { + ersteller_lib.LogError("Failed to get user id: %v", err) + return err + } + } + err = saveEmailToSessionStore(c, cookieStore, user.Email, userId, environment) + if err != nil { + return err + } + return c.Redirect(302, "/") + }) + + e.GET("/logout", func(c echo.Context) error { + // Get the session + session, err := cookieStore.Get(c.Request(), environment.SessionName) + if err != nil { + ersteller_lib.LogError("Failed to get session during logout: %v", err) + } else { + // Clear session values + session.Values = make(map[interface{}]interface{}) + // Set MaxAge to -1 to delete the cookie + session.Options.MaxAge = -1 + // Save the session (this will delete the cookie) + err = session.Save(c.Request(), c.Response()) + if err != nil { + ersteller_lib.LogError("Failed to clear session during logout: %v", err) + } + } + + // Also call gothic logout for OpenID Connect cleanup + gothic.Logout(c.Response(), c.Request()) + + // Redirect to login page + return c.Redirect(302, "/login") + }) + + e.GET("/logout/openid-connect", func(c echo.Context) error { + return gothic.Logout(c.Response(), c.Request()) + }) + + e.GET(OpenIdConnectPath, func(c echo.Context) error { + ctx := context.WithValue(c.Request().Context(), gothic.ProviderParamKey, "openid-connect") + request := c.Request().WithContext(ctx) + // try to get the user without re-authenticating + if gothUser, err := gothic.CompleteUserAuth(c.Response(), c.Request()); err == nil { + ersteller_lib.Debug(gothUser) + return nil + } else { + gothic.BeginAuthHandler(c.Response(), request) + return nil + } + }) + +} + +func createUser(gothUser goth.User, repo UserRepository) (int, error) { + return repo.Create(gothUser.Email) +} + +func saveEmailToSessionStore(c echo.Context, sessionStore *sessions.CookieStore, email string, userId int, environment KeycloakEnv) error { + session, err := sessionStore.New(c.Request(), environment.SessionName) + if err != nil { + ersteller_lib.LogError("Failed to create session: %v", err) + return err + } + session.Values = map[interface{}]interface{}{ + environment.EmailSessionKey: email, + environment.UserIdSessionKey: userId, + } + err = session.Save(c.Request(), c.Response()) + if err != nil { + ersteller_lib.LogError("Failed to save session: %v", err) + return err + } + return nil +} diff --git a/authentication/user.go b/authentication/user.go new file mode 100644 index 0000000..a065200 --- /dev/null +++ b/authentication/user.go @@ -0,0 +1,6 @@ +package authentication + +type UserRepository interface { + GetUserId(email string) (int, error) + Create(email string) (int, error) +} diff --git a/go.mod b/go.mod index 169e6f3..3f24a40 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/gorilla/sessions v1.4.0 github.com/jackc/pgx/v5 v5.7.5 github.com/labstack/echo/v4 v4.13.4 + github.com/markbates/goth v1.81.0 github.com/mattn/go-sqlite3 v1.14.29 golang.org/x/crypto v0.40.0 golang.org/x/oauth2 v0.30.0 @@ -16,6 +17,9 @@ require ( require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect + github.com/go-chi/chi/v5 v5.1.0 // indirect + github.com/gorilla/context v1.1.1 // indirect + github.com/gorilla/mux v1.6.2 // indirect github.com/gorilla/securecookie v1.1.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect diff --git a/go.sum b/go.sum index 99b924a..2066f3c 100644 --- a/go.sum +++ b/go.sum @@ -8,10 +8,16 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/doug-martin/goqu/v9 v9.19.0 h1:PD7t1X3tRcUiSdc5TEyOFKujZA5gs3VSA7wxSvBx7qo= github.com/doug-martin/goqu/v9 v9.19.0/go.mod h1:nf0Wc2/hV3gYK9LiyqIrzBEVGlI8qW3GuDCEobC4wBQ= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= +github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= +github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk= +github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= @@ -30,6 +36,8 @@ github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0 github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= github.com/lib/pq v1.10.1 h1:6VXZrLU0jHBYyAqrSPa+MgPfnSvTPuMgK+k0o5kVFWo= github.com/lib/pq v1.10.1/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/markbates/goth v1.81.0 h1:XVcCkeGWokynPV7MXvgb8pd2s3r7DS40P7931w6kdnE= +github.com/markbates/goth v1.81.0/go.mod h1:+6z31QyUms84EHmuBY7iuqYSxyoN3njIgg9iCF/lR1k= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -39,8 +47,9 @@ github.com/mattn/go-sqlite3 v1.14.29 h1:1O6nRLJKvsi1H2Sj0Hzdfojwt8GiGKm+LOfLaBFa github.com/mattn/go-sqlite3 v1.14.29/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= diff --git a/user/user.go b/user/user.go new file mode 100644 index 0000000..db9c3d3 --- /dev/null +++ b/user/user.go @@ -0,0 +1,36 @@ +package user + +import ( + "fmt" + "time" +) + +type UserState struct { + SomeValue string +} + +type User struct { + Id int `db:"id"` + Email string `db:"email"` + State UserState `db:"state"` + Admin bool `db:"admin"` + Password string `db:"password"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} + +func (s User) String() string { + return fmt.Sprint("User{ ", + "Id: ", s.Id, ", ", + "Email: ", s.Email, ", ", + "State: ", s.State, ", ", + "Admin: ", s.Admin, ", ", + "Password: ", s.Password, ", ", + "CreatedAt: ", s.CreatedAt, ", ", + "UpdatedAt: ", s.UpdatedAt, ", ", + "}") +} + +func (s User) GetId() int { + return s.Id +} diff --git a/user/user_repository.go b/user/user_repository.go new file mode 100644 index 0000000..8fcd0cb --- /dev/null +++ b/user/user_repository.go @@ -0,0 +1,394 @@ +package user + +import ( + "context" + "encoding/json" + "errors" + "ersteller-lib" + "fmt" + "github.com/doug-martin/goqu/v9" + _ "github.com/doug-martin/goqu/v9/dialect/postgres" + "github.com/doug-martin/goqu/v9/exp" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "strings" + "time" +) + +type UserRepository struct { + connPool *pgxpool.Pool + dialect goqu.DialectWrapper +} + +func NewUserRepository(connPool *pgxpool.Pool) *UserRepository { + return &UserRepository{ + connPool: connPool, + dialect: goqu.Dialect("postgres"), + } +} + +func (r *UserRepository) Create(user User) (int, error) { + sql, args, err := r.dialect.Insert("user"). + Prepared(true). + Rows(goqu.Record{ + + "updated_at": time.Now(), + "email": user.Email, + "state": r.jsonToString(user.State), + "admin": user.Admin, + "password": user.Password, + }). + Returning("id"). + ToSQL() + if err != nil { + ersteller_lib.LogError("error creating create User sql: %v", err) + return -1, err + } + + rows, err := r.connPool.Query(context.Background(), sql, args...) + if err != nil { + ersteller_lib.LogError("error creating User: %v", err) + return -1, err + } + defer rows.Close() + var id int + if rows.Next() { + err = rows.Scan(&id) + if err != nil { + ersteller_lib.LogError("error scanning User: %v", err) + return -1, err + } + } else { + ersteller_lib.Error("User already exists") + return -1, UserAlreadyExistsError{User: user} + } + + return id, nil + +} + +type UserAlreadyExistsError struct { + User User +} + +func (e UserAlreadyExistsError) Error() string { + return fmt.Sprint("User ", e.User, " already exists") +} + +func (r *UserRepository) getSelectColumns() []any { + return []any{"id", "created_at", "updated_at", + "email", "state", "admin", "password", + } +} + +func (r *UserRepository) Read(id int) (User, error) { + ersteller_lib.Debug("Getting User by id ", id) + sql, args, _ := r.dialect.From("user"). + Prepared(true). + Select(r.getSelectColumns()...). + Where(goqu.Ex{ + "id": id, + }). + ToSQL() + + rows, err := r.connPool.Query(context.Background(), sql, args...) + if err != nil { + ersteller_lib.Error("Failed to get User: ", err) + } + defer rows.Close() + if rows.Next() { + item, _, err := r.rowToItem(rows, false) + return item, err + } + return User{}, errors.New("no rows found") +} + +type UserItemScan struct { + User + RowId int + Count int +} + +func (r *UserRepository) rowToItem(rows pgx.Rows, rowId bool) (User, int, error) { + var item UserItemScan + if rowId { + err := rows.Scan( + &item.RowId, + &item.Count, + &item.Id, + &item.CreatedAt, + &item.UpdatedAt, + &item.Email, + &item.State, + &item.Admin, + &item.Password, + ) + if err != nil { + return User{}, -1, err + } + } else { + err := rows.Scan( + &item.Id, + &item.CreatedAt, + &item.UpdatedAt, + &item.Email, + &item.State, + &item.Admin, + &item.Password, + ) + if err != nil { + return User{}, -1, err + } + } + return User{ + Id: item.Id, + CreatedAt: item.CreatedAt, + UpdatedAt: item.UpdatedAt, + Email: item.Email, + State: item.State, + Admin: item.Admin, + Password: item.Password, + }, item.Count, nil +} + +func (r *UserRepository) Update(user User) error { + sql, args, err := r.dialect.Update("user"). + Prepared(true). + Set(goqu.Record{ + + "updated_at": time.Now(), + "email": user.Email, + "state": r.jsonToString(user.State), + "admin": user.Admin, + "password": user.Password, + }). + Where(goqu.Ex{ + "id": user.Id, + }). + ToSQL() + if err != nil { + ersteller_lib.LogError("error creating update User sql: %v", err) + return err + } + + _, err = r.connPool.Exec(context.Background(), sql, args...) + if err != nil { + ersteller_lib.LogError("error updating User: %v", err) + return err + } + + return nil +} + +func (r *UserRepository) Delete(id int) error { + sql, args, err := r.dialect.Delete("user"). + Prepared(true). + Where(goqu.Ex{ + "id": id, + }). + ToSQL() + if err != nil { + ersteller_lib.LogError("error creating delete User sql: %v", err) + return err + } + + _, err = r.connPool.Exec(context.Background(), sql, args...) + if err != nil { + ersteller_lib.LogError("error deleting User: %v", err) + return err + } + + return nil +} + +type UserField string + +const ( + UserFieldEmail UserField = "email" + UserFieldState UserField = "state" + UserFieldAdmin UserField = "admin" + UserFieldPassword UserField = "password" +) + +type UserEmailFilter struct { + Active bool + Value string +} + +type UserAdminFilter struct { + Active bool + Value bool +} + +type UserPasswordFilter struct { + Active bool + Value string +} + +type UserOrderDirection string + +const ( + UserOrderDirectionAsc UserOrderDirection = "asc" + UserOrderDirectionDesc UserOrderDirection = "desc" +) + +type UserPaginationParams struct { + RowId int + PageSize int + OrderBy UserField + OrderDirection UserOrderDirection + + EmailFilter UserEmailFilter + + AdminFilter UserAdminFilter + + PasswordFilter UserPasswordFilter +} + +func (r *UserRepository) GetPage(params UserPaginationParams) ([]User, int, error) { + var orderByWindow exp.WindowExpression + if params.OrderDirection == UserOrderDirectionAsc { + orderByWindow = goqu.W().OrderBy(goqu.C(string(params.OrderBy)).Asc()) + } else { + orderByWindow = goqu.W().OrderBy(goqu.C(string(params.OrderBy)).Desc()) + } + selectColumns := []any{ + goqu.ROW_NUMBER().Over(orderByWindow).As("row_id"), + goqu.COUNT("*"), + } + selectColumns = append(selectColumns, r.getSelectColumns()...) + whereExpressions := []goqu.Expression{ + goqu.Ex{}, + } + whereExpressions = r.addPageFilters(params, whereExpressions) + var colOrder exp.OrderedExpression + if params.OrderDirection == UserOrderDirectionAsc { + colOrder = goqu.C(string(params.OrderBy)).Asc() + } else { + colOrder = goqu.C(string(params.OrderBy)).Desc() + } + dialect := goqu.Dialect("postgres") + innerFrom := dialect.From("user"). + Prepared(true). + Select(selectColumns...). + Where(whereExpressions...). + Order(colOrder) + + outerFrom := dialect.From(innerFrom). + Prepared(true). + Where(goqu.Ex{"row_id": goqu.Op{"gt": params.RowId}}) + if params.PageSize > 0 { + outerFrom = outerFrom.Limit(uint(params.PageSize)) + } + sql, args, _ := outerFrom.ToSQL() + sql = strings.Replace(sql, "COUNT(*)", "COUNT(*) over()", 1) + + rows, err := r.connPool.Query(context.Background(), sql, args...) + if err != nil { + ersteller_lib.LogError("failed to run sql query: %v", err) + return nil, -1, err + } + defer rows.Close() + results := make([]User, 0) + totalCount := 0 + for rows.Next() { + parsed, count, err := r.rowToItem(rows, true) + if err != nil { + return nil, -1, err + } + totalCount = count + results = append(results, parsed) + } + return results, totalCount, nil +} + +func (r *UserRepository) addPageFilters(params UserPaginationParams, whereExpressions []goqu.Expression) []goqu.Expression { + + if params.EmailFilter.Active { + whereExpressions = append(whereExpressions, goqu.Ex{ + "email": goqu.Op{"ilike": fmt.Sprint("%", params.EmailFilter.Value, "%")}, + }) + } + + if params.AdminFilter.Active { + whereExpressions = append(whereExpressions, goqu.Ex{ + "admin": params.AdminFilter.Value, + }) + } + + if params.PasswordFilter.Active { + whereExpressions = append(whereExpressions, goqu.Ex{ + "password": goqu.Op{"ilike": fmt.Sprint("%", params.PasswordFilter.Value, "%")}, + }) + } + + return whereExpressions +} + +func (r *UserRepository) jsonToString(jsonData any) string { + bytes, err := json.Marshal(jsonData) + if err != nil { + return "{}" + } + return string(bytes) +} + +func (u *UserRepository) DoesUserEmailExist(email string) (bool, error) { + sql, args, _ := u.dialect.From("user"). + Prepared(true). + Select(goqu.COUNT("email")). + Where(goqu.Ex{"email": email}). + ToSQL() + + rows, err := u.connPool.Query(context.Background(), sql, args...) + if err != nil { + ersteller_lib.LogError("failed to run sql query: %v", err) + return false, err + } + defer rows.Close() + if rows.Next() { + var count int + err = rows.Scan(&count) + if err != nil { + return false, err + } + return count == 1, nil + } + return false, nil +} + +func (u *UserRepository) GetUserId(email string) (int, error) { + sql, args, _ := u.dialect.From("user"). + Prepared(true). + Select("id"). + Where(goqu.Ex{"email": email}). + ToSQL() + + rows, err := u.connPool.Query(context.Background(), sql, args...) + if err != nil { + ersteller_lib.LogError("failed to run sql query: %v", err) + return -1, err + } + defer rows.Close() + if rows.Next() { + var id int + err = rows.Scan(&id) + if err != nil { + return -1, err + } + return id, nil + } + return -1, errors.New("did not find user with email " + email) +} + +func (r *UserRepository) VerifyPassword(email string, password string) (bool, int, error) { + userId, err := r.GetUserId(email) + if err != nil { + return false, -1, err + } + user, err := r.Read(userId) + if err != nil { + return false, -1, err + } + return ersteller_lib.VerifyPassword(password, user.Password), userId, nil +}