1
0

Make gists username/urls case insensitive in URLS (#641)

Signed-off-by: Thomas Miceli <tho.miceli@gmail.com>
هذا الالتزام موجود في:
Thomas Miceli
2026-03-03 14:28:49 +07:00
ملتزم من قبل GitHub
الأصل 4ab38f24c8
التزام d796eeba98
5 ملفات معدلة مع 134 إضافات و55 حذوفات

عرض الملف

@@ -71,6 +71,7 @@ type Gist struct {
Uuid string Uuid string
Title string Title string
URL string URL string
URLNormalized string
Preview string Preview string
PreviewFilename string PreviewFilename string
PreviewMimeType string PreviewMimeType string
@@ -98,6 +99,11 @@ type Like struct {
CreatedAt int64 CreatedAt int64
} }
func (gist *Gist) BeforeSave(_ *gorm.DB) error {
gist.URLNormalized = strings.ToLower(gist.URL)
return nil
}
func (gist *Gist) BeforeDelete(tx *gorm.DB) error { func (gist *Gist) BeforeDelete(tx *gorm.DB) error {
// Decrement fork counter if the gist was forked // Decrement fork counter if the gist was forked
err := tx.Model(&Gist{}). err := tx.Model(&Gist{}).
@@ -110,7 +116,8 @@ func (gist *Gist) BeforeDelete(tx *gorm.DB) error {
func GetGist(user string, gistUuid string) (*Gist, error) { func GetGist(user string, gistUuid string) (*Gist, error) {
gist := new(Gist) gist := new(Gist)
err := db.Preload("User").Preload("Forked.User").Preload("Topics"). err := db.Preload("User").Preload("Forked.User").Preload("Topics").
Where("(gists.uuid like ? OR gists.url = ?) AND users.username like ?", gistUuid+"%", gistUuid, user). Where("(gists.uuid LIKE ? OR gists.url_normalized = ?) AND users.username_normalized = ?",
strings.ToLower(gistUuid)+"%", strings.ToLower(gistUuid), strings.ToLower(user)).
Joins("join users on gists.user_id = users.id"). Joins("join users on gists.user_id = users.id").
First(&gist).Error First(&gist).Error

عرض الملف

@@ -2,7 +2,9 @@ package db
import ( import (
"fmt" "fmt"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm"
) )
type MigrationVersion struct { type MigrationVersion struct {
@@ -12,60 +14,74 @@ type MigrationVersion struct {
func applyMigrations(dbInfo *databaseInfo) error { func applyMigrations(dbInfo *databaseInfo) error {
switch dbInfo.Type { switch dbInfo.Type {
case SQLite: case SQLite, PostgreSQL, MySQL:
return applySqliteMigrations() return applyAllMigrations(dbInfo.Type)
case PostgreSQL, MySQL:
return nil
default: default:
return fmt.Errorf("unknown database type: %s", dbInfo.Type) return fmt.Errorf("unknown database type: %s", dbInfo.Type)
} }
} }
func applySqliteMigrations() error { func applyAllMigrations(dbType databaseType) error {
// Create migration table if it doesn't exist
if err := db.AutoMigrate(&MigrationVersion{}); err != nil { if err := db.AutoMigrate(&MigrationVersion{}); err != nil {
log.Fatal().Err(err).Msg("Error creating migration version table") log.Fatal().Err(err).Msg("Error creating migration version table")
return err return err
} }
// Get the current migration version
var currentVersion MigrationVersion var currentVersion MigrationVersion
db.First(&currentVersion) db.First(&currentVersion)
// Define migrations
migrations := []struct { migrations := []struct {
Version uint Version uint
DBTypes []databaseType // nil = all types
Func func() error Func func() error
}{ }{
{1, v1_modifyConstraintToSSHKeys}, {1, []databaseType{SQLite}, v1_modifyConstraintToSSHKeys},
{2, v2_lowercaseEmails}, {2, []databaseType{SQLite}, v2_lowercaseEmails},
// Add more migrations here as needed {3, nil, v3_normalizedColumns},
} }
// Apply migrations
for _, m := range migrations { for _, m := range migrations {
if m.Version > currentVersion.Version { if m.Version <= currentVersion.Version {
tx := db.Begin() continue
if err := tx.Error; err != nil { }
log.Fatal().Err(err).Msg("Error starting transaction")
return err
}
if err := m.Func(); err != nil { // Skip migrations not intended for this DB type
log.Fatal().Err(err).Msg(fmt.Sprintf("Error applying migration %d:", m.Version)) if len(m.DBTypes) > 0 {
tx.Rollback() applicable := false
return err for _, t := range m.DBTypes {
} else { if t == dbType {
if err = tx.Commit().Error; err != nil { applicable = true
log.Fatal().Err(err).Msg(fmt.Sprintf("Error committing migration %d:", m.Version)) break
return err
} }
}
if !applicable {
// Advance version so we don't retry on next startup
currentVersion.Version = m.Version currentVersion.Version = m.Version
db.Save(&currentVersion) db.Save(&currentVersion)
log.Info().Msg(fmt.Sprintf("Migration %d applied successfully", m.Version)) continue
} }
} }
tx := db.Begin()
if err := tx.Error; err != nil {
log.Fatal().Err(err).Msg("Error starting transaction")
return err
}
if err := m.Func(); err != nil {
tx.Rollback()
log.Fatal().Err(err).Msg(fmt.Sprintf("Error applying migration %d:", m.Version))
return err
}
if err := tx.Commit().Error; err != nil {
log.Fatal().Err(err).Msg(fmt.Sprintf("Error committing migration %d:", m.Version))
return err
}
currentVersion.Version = m.Version
db.Save(&currentVersion)
log.Info().Msg(fmt.Sprintf("Migration %d applied successfully", m.Version))
} }
return nil return nil
@@ -112,3 +128,12 @@ func v2_lowercaseEmails() error {
copySQL := `UPDATE users SET email = lower(email);` copySQL := `UPDATE users SET email = lower(email);`
return db.Exec(copySQL).Error return db.Exec(copySQL).Error
} }
func v3_normalizedColumns() error {
if err := db.Model(&User{}).Where("username_normalized = '' OR username_normalized IS NULL").
Updates(map[string]interface{}{"username_normalized": gorm.Expr("LOWER(username)")}).Error; err != nil {
return err
}
return db.Model(&Gist{}).Where("url_normalized = '' OR url_normalized IS NULL").
Updates(map[string]interface{}{"url_normalized": gorm.Expr("LOWER(url)")}).Error
}

عرض الملف

@@ -2,24 +2,27 @@ package db
import ( import (
"encoding/json" "encoding/json"
"strings"
"github.com/thomiceli/opengist/internal/git" "github.com/thomiceli/opengist/internal/git"
"gorm.io/gorm" "gorm.io/gorm"
) )
type User struct { type User struct {
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
Username string `gorm:"uniqueIndex,size:191"` Username string `gorm:"uniqueIndex,size:191"`
Password string UsernameNormalized string `gorm:"index"`
IsAdmin bool Password string
CreatedAt int64 IsAdmin bool
Email string CreatedAt int64
MD5Hash string // for gravatar, if no Email is specified, the value is random Email string
AvatarURL string MD5Hash string // for gravatar, if no Email is specified, the value is random
GithubID string AvatarURL string
GitlabID string GithubID string
GiteaID string GitlabID string
OIDCID string `gorm:"column:oidc_id"` GiteaID string
StylePreferences string OIDCID string `gorm:"column:oidc_id"`
StylePreferences string
Gists []Gist `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;foreignKey:UserID"` Gists []Gist `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;foreignKey:UserID"`
SSHKeys []SSHKey `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;foreignKey:UserID"` SSHKeys []SSHKey `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;foreignKey:UserID"`
@@ -28,6 +31,11 @@ type User struct {
AccessTokens []AccessToken `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;foreignKey:UserID"` AccessTokens []AccessToken `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;foreignKey:UserID"`
} }
func (user *User) BeforeSave(_ *gorm.DB) error {
user.UsernameNormalized = strings.ToLower(user.Username)
return nil
}
func (user *User) BeforeDelete(tx *gorm.DB) error { func (user *User) BeforeDelete(tx *gorm.DB) error {
// Decrement likes counter using derived table // Decrement likes counter using derived table
err := tx.Exec(` err := tx.Exec(`
@@ -93,7 +101,7 @@ func (user *User) BeforeDelete(tx *gorm.DB) error {
func UserExists(username string) (bool, error) { func UserExists(username string) (bool, error) {
var count int64 var count int64
err := db.Model(&User{}).Where("username like ?", username).Count(&count).Error err := db.Model(&User{}).Where("username_normalized = ?", strings.ToLower(username)).Count(&count).Error
return count > 0, err return count > 0, err
} }
@@ -111,7 +119,7 @@ func GetAllUsers(offset int) ([]*User, error) {
func GetUserByUsername(username string) (*User, error) { func GetUserByUsername(username string) (*User, error) {
user := new(User) user := new(User)
err := db. err := db.
Where("username like ?", username). Where("username_normalized = ?", strings.ToLower(username)).
First(&user).Error First(&user).Error
return user, err return user, err
} }

عرض الملف

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"net/url" "net/url"
"strings"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -291,3 +292,36 @@ func TestGistAccess(t *testing.T) {
}) })
} }
} }
func TestGetGistCaseInsensitive(t *testing.T) {
s := webtest.Setup(t)
defer webtest.Teardown(t)
s.Register(t, "THOmas")
s.Login(t, "THOmas")
s.Request(t, "POST", "/", url.Values{
"title": {"Test"},
"name": {"file.txt"},
"content": {"hello world"},
"url": {"my-GIST"},
"private": {"0"},
}, 302)
gist, err := db.GetGistByID("1")
require.NoError(t, err)
s.Logout()
t.Run("URL", func(t *testing.T) {
s.Request(t, "GET", "/thomas/my-gist", nil, 200)
s.Request(t, "GET", "/THOMAS/MY-GIST", nil, 200)
s.Request(t, "GET", "/thomas/MY-GIST", nil, 200)
s.Request(t, "GET", "/THOMAS/my-gist", nil, 200)
})
t.Run("UUID", func(t *testing.T) {
s.Request(t, "GET", "/thomas/"+strings.ToLower(gist.Uuid), nil, 200)
s.Request(t, "GET", "/THOMAS/"+strings.ToUpper(gist.Uuid), nil, 200)
})
}

عرض الملف

@@ -3,16 +3,17 @@ package settings
import ( import (
"crypto/md5" "crypto/md5"
"fmt" "fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/thomiceli/opengist/internal/config" "github.com/thomiceli/opengist/internal/config"
"github.com/thomiceli/opengist/internal/db" "github.com/thomiceli/opengist/internal/db"
"github.com/thomiceli/opengist/internal/git" "github.com/thomiceli/opengist/internal/git"
"github.com/thomiceli/opengist/internal/i18n" "github.com/thomiceli/opengist/internal/i18n"
"github.com/thomiceli/opengist/internal/validator" "github.com/thomiceli/opengist/internal/validator"
"github.com/thomiceli/opengist/internal/web/context" "github.com/thomiceli/opengist/internal/web/context"
"os"
"path/filepath"
"strings"
"time"
) )
func EmailProcess(ctx *context.Context) error { func EmailProcess(ctx *context.Context) error {
@@ -61,18 +62,22 @@ func UsernameProcess(ctx *context.Context) error {
return ctx.RedirectTo("/settings") return ctx.RedirectTo("/settings")
} }
if exists, err := db.UserExists(dto.Username); err != nil || exists { if !strings.EqualFold(dto.Username, user.Username) {
ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error") if exists, err := db.UserExists(dto.Username); err != nil || exists {
return ctx.RedirectTo("/settings") ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error")
return ctx.RedirectTo("/settings")
}
} }
sourceDir := filepath.Join(config.GetHomeDir(), git.ReposDirectory, strings.ToLower(user.Username)) sourceDir := filepath.Join(config.GetHomeDir(), git.ReposDirectory, strings.ToLower(user.Username))
destinationDir := filepath.Join(config.GetHomeDir(), git.ReposDirectory, strings.ToLower(dto.Username)) destinationDir := filepath.Join(config.GetHomeDir(), git.ReposDirectory, strings.ToLower(dto.Username))
if _, err := os.Stat(sourceDir); !os.IsNotExist(err) { if sourceDir != destinationDir {
err := os.Rename(sourceDir, destinationDir) if _, err := os.Stat(sourceDir); !os.IsNotExist(err) {
if err != nil { err := os.Rename(sourceDir, destinationDir)
return ctx.ErrorRes(500, "Cannot rename user directory", err) if err != nil {
return ctx.ErrorRes(500, "Cannot rename user directory", err)
}
} }
} }