package db

import (
	"errors"
	"fmt"

	"github.com/rs/zerolog/log"
	"gorm.io/gorm"
)

// MigrationVersion records the highest schema migration that has been applied
// (or deliberately skipped) for this database. applyAllMigrations reads the
// first row of this table and updates that same row as migrations complete.
type MigrationVersion struct {
	ID      uint `gorm:"primaryKey"`
	Version uint
}

// applyMigrations runs every pending migration for the database described by
// dbInfo. It returns an error for database types it does not recognize.
func applyMigrations(dbInfo *databaseInfo) error {
	switch dbInfo.Type {
	case SQLite, PostgreSQL, MySQL:
		return applyAllMigrations(dbInfo.Type)
	default:
		return fmt.Errorf("unknown database type: %s", dbInfo.Type)
	}
}

// applyAllMigrations brings the schema up to date for a database of type dbType.
//
// It ensures the MigrationVersion table exists, reads the stored version (no
// row yet means version 0), then walks the ordered migration list:
//   - migrations at or below the stored version are skipped;
//   - migrations restricted to other database types are skipped, but the
//     stored version is still advanced so they are not reconsidered later;
//   - every other migration runs in its own transaction together with the
//     version bump, so a failure leaves neither the change nor the new version
//     behind.
//
// Errors are returned wrapped with the step that failed; the caller decides
// whether that is fatal.
func applyAllMigrations(dbType databaseType) error {
	if err := db.AutoMigrate(&MigrationVersion{}); err != nil {
		return fmt.Errorf("creating migration version table: %w", err)
	}

	var currentVersion MigrationVersion
	// A fresh database has no row yet; that is not an error, it means version 0.
	if err := db.First(&currentVersion).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return fmt.Errorf("reading current migration version: %w", err)
	}

	migrations := []struct {
		Version uint
		DBTypes []databaseType // nil = all types
		Func    func(tx *gorm.DB) error
	}{
		{1, []databaseType{SQLite}, migrateSSHKeysCascade},
		{2, []databaseType{SQLite}, migrateLowercaseEmails},
		{3, nil, migrateNormalizedColumns},
	}

	for _, m := range migrations {
		if m.Version <= currentVersion.Version {
			continue
		}

		if !migrationAppliesTo(m.DBTypes, dbType) {
			// Advance version so we don't retry on next startup.
			currentVersion.Version = m.Version
			if err := db.Save(&currentVersion).Error; err != nil {
				return fmt.Errorf("recording skipped migration %d: %w", m.Version, err)
			}
			continue
		}

		// db.Transaction commits when the callback returns nil and rolls back
		// on error or panic. Both the migration and the version bump go through
		// tx so they succeed or fail together.
		err := db.Transaction(func(tx *gorm.DB) error {
			if err := m.Func(tx); err != nil {
				return err
			}
			currentVersion.Version = m.Version
			return tx.Save(&currentVersion).Error
		})
		if err != nil {
			return fmt.Errorf("applying migration %d: %w", m.Version, err)
		}

		log.Info().Msgf("Migration %d applied successfully", m.Version)
	}
	return nil
}

// migrationAppliesTo reports whether a migration restricted to types should run
// on a database of type dbType. An empty types list means "all types".
func migrationAppliesTo(types []databaseType, dbType databaseType) bool {
	if len(types) == 0 {
		return true
	}
	for _, t := range types {
		if t == dbType {
			return true
		}
	}
	return false
}

// migrateSSHKeysCascade (migration 1, SQLite only) rebuilds the ssh_keys table
// so its user_id foreign key uses ON UPDATE CASCADE and ON DELETE CASCADE.
// SQLite's ALTER TABLE cannot change an existing constraint, so the table is
// recreated under a temporary name, its rows copied over, the old table
// dropped, and the copy renamed back. All statements run on tx.
func migrateSSHKeysCascade(tx *gorm.DB) error {
	statements := []string{
		`CREATE TABLE ssh_keys_temp (
	id integer primary key,
	title text,
	content text,
	sha text,
	created_at integer,
	last_used_at integer,
	user_id integer
		constraint fk_users_ssh_keys
			references users(id)
			on update cascade
			on delete cascade
);`,
		// SELECT * copies by column position: this assumes the old ssh_keys
		// table has exactly the columns above, in the same order.
		`INSERT INTO ssh_keys_temp SELECT * FROM ssh_keys;`,
		`DROP TABLE ssh_keys;`,
		`ALTER TABLE ssh_keys_temp RENAME TO ssh_keys;`,
	}
	for _, stmt := range statements {
		if err := tx.Exec(stmt).Error; err != nil {
			return err
		}
	}
	return nil
}

// migrateLowercaseEmails (migration 2, SQLite only) lowercases every value in
// users.email in place.
func migrateLowercaseEmails(tx *gorm.DB) error {
	return tx.Exec(`UPDATE users SET email = lower(email);`).Error
}

// migrateNormalizedColumns (migration 3, all databases) backfills the
// normalized lookup columns. It sets username_normalized to LOWER(username) on
// User rows, and url_normalized to LOWER(url) on Gist rows, wherever the
// normalized value is empty or NULL. Rows that already have a value are left
// untouched.
func migrateNormalizedColumns(tx *gorm.DB) error {
	if err := tx.Model(&User{}).
		Where("username_normalized = '' OR username_normalized IS NULL").
		Updates(map[string]interface{}{"username_normalized": gorm.Expr("LOWER(username)")}).Error; err != nil {
		return err
	}
	return tx.Model(&Gist{}).
		Where("url_normalized = '' OR url_normalized IS NULL").
		Updates(map[string]interface{}{"url_normalized": gorm.Expr("LOWER(url)")}).Error
}

// v1_modifyConstraintToSSHKeys runs migration 1 directly on the package-level
// db handle, outside any transaction. It is kept so existing references still
// compile; applyAllMigrations runs migrateSSHKeysCascade inside a transaction.
func v1_modifyConstraintToSSHKeys() error {
	return migrateSSHKeysCascade(db)
}

// v2_lowercaseEmails runs migration 2 directly on the package-level db handle.
// It is kept so existing references still compile; applyAllMigrations runs
// migrateLowercaseEmails inside a transaction.
func v2_lowercaseEmails() error {
	return migrateLowercaseEmails(db)
}

// v3_normalizedColumns runs migration 3 directly on the package-level db
// handle. It is kept so existing references still compile; applyAllMigrations
// runs migrateNormalizedColumns inside a transaction.
func v3_normalizedColumns() error {
	return migrateNormalizedColumns(db)
}