Files
Gists/internal/db/migration.go
2026-03-03 15:28:49 +08:00

140 lines
3.4 KiB
Go

package db
import (
"fmt"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
)
// MigrationVersion is the bookkeeping record that stores the highest
// migration version already applied to this database. applyAllMigrations
// reads it at startup and updates it after each migration step.
type MigrationVersion struct {
	ID      uint `gorm:"primaryKey"`
	Version uint
}
// applyMigrations runs the versioned migrations for the configured backend.
// Only SQLite, PostgreSQL and MySQL are accepted; any other database type is
// rejected with an error before any migration work is attempted.
func applyMigrations(dbInfo *databaseInfo) error {
	dbType := dbInfo.Type
	if dbType != SQLite && dbType != PostgreSQL && dbType != MySQL {
		return fmt.Errorf("unknown database type: %s", dbType)
	}
	return applyAllMigrations(dbType)
}
// applyAllMigrations brings the database up to the latest migration version.
//
// The highest applied version is kept in the MigrationVersion table, which is
// created via AutoMigrate if needed. Each pending migration whose DBTypes list
// contains dbType (or is nil, meaning every backend) runs inside its own
// transaction together with the version bump. A failing migration therefore
// leaves neither partial changes nor an advanced version behind. Migrations
// that do not apply to dbType are recorded as done without being run.
//
// Failures are returned wrapped with the step that failed, rather than
// terminating the process here; callers should treat a non-nil error as
// fatal for startup.
func applyAllMigrations(dbType databaseType) error {
	if err := db.AutoMigrate(&MigrationVersion{}); err != nil {
		return fmt.Errorf("creating migration version table: %w", err)
	}

	// Find (unlike First) does not report "record not found", so an empty
	// table leaves currentVersion at its zero value (version 0). Genuine read
	// errors are surfaced instead of being mistaken for a fresh database,
	// which would re-run every migration.
	var currentVersion MigrationVersion
	if err := db.Order("id").Limit(1).Find(&currentVersion).Error; err != nil {
		return fmt.Errorf("reading current migration version: %w", err)
	}

	migrations := []struct {
		Version uint
		DBTypes []databaseType // nil = all types
		Func    func(tx *gorm.DB) error
	}{
		{1, []databaseType{SQLite}, v1_modifyConstraintToSSHKeys},
		{2, []databaseType{SQLite}, v2_lowercaseEmails},
		{3, nil, v3_normalizedColumns},
	}

	for _, m := range migrations {
		if m.Version <= currentVersion.Version {
			continue
		}

		applicable := len(m.DBTypes) == 0
		for _, t := range m.DBTypes {
			if t == dbType {
				applicable = true
				break
			}
		}

		// Work on a copy so the in-memory version only advances once the
		// change has actually been persisted.
		next := currentVersion
		next.Version = m.Version

		if !applicable {
			// Advance version so we don't retry on next startup.
			if err := db.Save(&next).Error; err != nil {
				return fmt.Errorf("recording skipped migration %d: %w", m.Version, err)
			}
			currentVersion = next
			continue
		}

		// The migration and its version bump share one transaction handle,
		// so they commit or roll back together.
		err := db.Transaction(func(tx *gorm.DB) error {
			if err := m.Func(tx); err != nil {
				return err
			}
			return tx.Save(&next).Error
		})
		if err != nil {
			return fmt.Errorf("applying migration %d: %w", m.Version, err)
		}
		currentVersion = next
		log.Info().Msgf("Migration %d applied successfully", m.Version)
	}
	return nil
}

// v1_modifyConstraintToSSHKeys rebuilds the ssh_keys table so that its foreign
// key to users uses ON UPDATE CASCADE and ON DELETE CASCADE. The table is
// recreated under a temporary name, the rows are copied, the old table is
// dropped, and the new one is renamed into place. Registered for SQLite only.
// All statements run on tx, so the caller's transaction covers the rebuild.
func v1_modifyConstraintToSSHKeys(tx *gorm.DB) error {
	createSQL := `
		CREATE TABLE ssh_keys_temp (
			id integer primary key,
			title text,
			content text,
			sha text,
			created_at integer,
			last_used_at integer,
			user_id integer
				constraint fk_users_ssh_keys references users(id) on update cascade on delete cascade
		);
	`
	steps := []string{
		createSQL,
		// Copy data from the old table to the new table. SELECT * is
		// positional, so this relies on ssh_keys having the same column
		// order as ssh_keys_temp.
		`INSERT INTO ssh_keys_temp SELECT * FROM ssh_keys;`,
		// Drop the old table.
		`DROP TABLE ssh_keys;`,
		// Rename the new table to the original table name.
		`ALTER TABLE ssh_keys_temp RENAME TO ssh_keys;`,
	}
	for _, stmt := range steps {
		if err := tx.Exec(stmt).Error; err != nil {
			return err
		}
	}
	return nil
}

// v2_lowercaseEmails lower-cases every stored user email in place.
// Registered for SQLite only.
func v2_lowercaseEmails(tx *gorm.DB) error {
	return tx.Exec(`UPDATE users SET email = lower(email);`).Error
}

// v3_normalizedColumns backfills the username_normalized column of User rows
// with LOWER(username), and the url_normalized column of Gist rows with
// LOWER(url). Only rows where the normalized value is empty or NULL are
// touched. Runs on every database type.
func v3_normalizedColumns(tx *gorm.DB) error {
	if err := tx.Model(&User{}).Where("username_normalized = '' OR username_normalized IS NULL").
		Updates(map[string]interface{}{"username_normalized": gorm.Expr("LOWER(username)")}).Error; err != nil {
		return err
	}
	return tx.Model(&Gist{}).Where("url_normalized = '' OR url_normalized IS NULL").
		Updates(map[string]interface{}{"url_normalized": gorm.Expr("LOWER(url)")}).Error
}