Change download count atomically to prevent race condition (#223)

This commit is contained in:
Marc Ole Bulling
2024-12-12 17:02:31 +01:00
committed by GitHub
parent 1611127bfa
commit ed8d476f3c
11 changed files with 249 additions and 54 deletions
@@ -200,6 +200,11 @@ func DeleteMetaData(id string) {
db.DeleteMetaData(id)
}
// IncreaseDownloadCount increases the download count of a file, preventing race conditions.
// It delegates to the database layer, which performs the increment in a single operation.
// If decreaseRemainingDownloads is true, the remaining-downloads counter is decremented as well.
func IncreaseDownloadCount(id string, decreaseRemainingDownloads bool) {
	db.IncreaseDownloadCount(id, decreaseRemainingDownloads)
}
// Session Section
// GetSession returns the session with the given ID or false if not a valid ID
@@ -158,6 +158,24 @@ func TestMetaData(t *testing.T) {
runAllTypesCompareOutput(t, func() any { return GetAllMetaDataIds() }, []string{})
runAllTypesCompareOutput(t, func() any { return GetAllMetadata() }, map[string]models.File{})
runAllTypesCompareTwoOutputs(t, func() (any, any) { return GetMetaDataById("testid") }, models.File{}, false)
increasedDownload := file
increasedDownload.DownloadCount = increasedDownload.DownloadCount + 1
runAllTypesCompareTwoOutputs(t, func() (any, any) {
SaveMetaData(file)
IncreaseDownloadCount(file.Id, false)
return GetMetaDataById(file.Id)
}, increasedDownload, true)
increasedDownload.DownloadCount = increasedDownload.DownloadCount + 1
increasedDownload.DownloadsRemaining = increasedDownload.DownloadsRemaining - 1
runAllTypesCompareTwoOutputs(t, func() (any, any) {
IncreaseDownloadCount(file.Id, true)
return GetMetaDataById(file.Id)
}, increasedDownload, true)
runAllTypesNoOutput(t, func() { DeleteMetaData(file.Id) })
}
func TestUpgrade(t *testing.T) {
@@ -72,6 +72,8 @@ type Database interface {
SaveMetaData(file models.File)
// DeleteMetaData deletes information about a file
DeleteMetaData(id string)
// IncreaseDownloadCount increases the download count of a file, preventing race conditions
IncreaseDownloadCount(id string, decreaseRemainingDownloads bool)
// GetSession returns the session with the given ID or false if not a valid ID
GetSession(id string) (models.Session, bool)
@@ -1,6 +1,8 @@
package redis
import (
"bytes"
"encoding/gob"
"errors"
"fmt"
"github.com/forceu/gokapi/internal/helper"
@@ -17,7 +19,7 @@ type DatabaseProvider struct {
dbPrefix string
}
const DatabaseSchemeVersion = 2
const DatabaseSchemeVersion = 3
// New returns an instance
func New(dbConfig models.DbConnection) (DatabaseProvider, error) {
@@ -91,8 +93,35 @@ func newPool(config models.DbConnection) *redigo.Pool {
// Upgrade migrates the DB to a new Gokapi version, if required.
// currentDbVersion is the schema version found in the database before the upgrade.
func (p DatabaseProvider) Upgrade(currentDbVersion int) {
	// < 1.9.6 (schema version < 3): metadata was stored as a single gob-encoded
	// blob per file. Read and delete every legacy entry, then re-save it so it
	// is written in the current format.
	// NOTE: a stale early `return` left over here previously made this
	// migration unreachable; it has been removed.
	if currentDbVersion < 3 {
		allFiles := getAllLegacyMetaDataAndDelete(p)
		for _, file := range allFiles {
			p.SaveMetaData(file)
		}
	}
}
// getAllLegacyMetaDataAndDelete reads every gob-encoded legacy metadata entry,
// deletes it from the database and returns the decoded files keyed by file ID.
func getAllLegacyMetaDataAndDelete(p DatabaseProvider) map[string]models.File {
	files := make(map[string]models.File)
	for _, raw := range p.getAllValuesWithPrefix(prefixMetaData) {
		content, err := redigo.Bytes(raw, nil)
		helper.Check(err)
		decoded := legacyDbToMetaData(content)
		files[decoded.Id] = decoded
		p.deleteKey(prefixMetaData + decoded.Id)
	}
	return files
}
// legacyDbToMetaData decodes a gob-encoded models.File as written by
// database schema versions < 3.
func legacyDbToMetaData(input []byte) models.File {
	var file models.File
	err := gob.NewDecoder(bytes.NewReader(input)).Decode(&file)
	helper.Check(err)
	return file
}
const keyDbVersion = "dbversion"
@@ -260,6 +289,20 @@ func (p DatabaseProvider) deleteKey(id string) {
helper.Check(err)
}
// increaseHashmapIntField increments the given integer field of the hashmap
// stored under id by one, using a single HINCRBY command.
func (p DatabaseProvider) increaseHashmapIntField(id string, field string) {
	c := p.pool.Get()
	defer c.Close()
	_, err := c.Do("HINCRBY", p.dbPrefix+id, field, 1)
	helper.Check(err)
}
// decreaseHashmapIntField decrements the given integer field of the hashmap
// stored under id by one, using a single HINCRBY command with a negative delta.
func (p DatabaseProvider) decreaseHashmapIntField(id string, field string) {
	c := p.pool.Get()
	defer c.Close()
	_, err := c.Do("HINCRBY", p.dbPrefix+id, field, -1)
	helper.Check(err)
}
func (p DatabaseProvider) runEval(cmd string) {
conn := p.pool.Get()
defer conn.Close()
@@ -241,6 +241,49 @@ func TestApiKeys(t *testing.T) {
test.IsEqualBool(t, key.LastUsed == 10, true)
}
// TestDatabaseProvider_IncreaseDownloadCount verifies that IncreaseDownloadCount
// increments DownloadCount by one and only decrements DownloadsRemaining when
// requested, while leaving every other metadata field untouched.
func TestDatabaseProvider_IncreaseDownloadCount(t *testing.T) {
	// Fully-populated file so the struct comparisons below catch accidental changes
	// to unrelated fields
	newFile := models.File{
		Id:                 "newFileId",
		Name:               "newFileName",
		Size:               "3GB",
		SHA1:               "newSHA1",
		PasswordHash:       "newPassword",
		HotlinkId:          "newHotlink",
		ContentType:        "newContent",
		AwsBucket:          "newAws",
		ExpireAt:           123456,
		SizeBytes:          456789,
		DownloadsRemaining: 11,
		DownloadCount:      2,
		Encryption: models.EncryptionInfo{
			IsEncrypted:         true,
			IsEndToEndEncrypted: true,
			DecryptionKey:       []byte("newDecryptionKey"),
			Nonce:               []byte("newDecryptionNonce"),
		},
		UnlimitedDownloads: true,
		UnlimitedTime:      true,
	}
	dbInstance.SaveMetaData(newFile)
	// Increase only the download count; remaining downloads must stay unchanged
	dbInstance.IncreaseDownloadCount(newFile.Id, false)
	retrievedFile, ok := dbInstance.GetMetaDataById(newFile.Id)
	test.IsEqualBool(t, ok, true)
	test.IsEqualInt(t, retrievedFile.DownloadCount, 3)
	test.IsEqualInt(t, retrievedFile.DownloadsRemaining, 11)
	newFile.DownloadCount = 3
	test.IsEqual(t, retrievedFile, newFile)
	// Increase the download count and also decrease the remaining downloads
	dbInstance.IncreaseDownloadCount(newFile.Id, true)
	retrievedFile, ok = dbInstance.GetMetaDataById(newFile.Id)
	test.IsEqualBool(t, ok, true)
	test.IsEqualInt(t, retrievedFile.DownloadCount, 4)
	test.IsEqualInt(t, retrievedFile.DownloadsRemaining, 10)
	newFile.DownloadCount = 4
	newFile.DownloadsRemaining = 10
	test.IsEqual(t, retrievedFile, newFile)
	dbInstance.DeleteMetaData(newFile.Id)
}
func TestE2EConfig(t *testing.T) {
e2econfig := models.E2EInfoEncrypted{
Version: 1,
@@ -1,8 +1,6 @@
package redis
import (
"bytes"
"encoding/gob"
"github.com/forceu/gokapi/internal/helper"
"github.com/forceu/gokapi/internal/models"
redigo "github.com/gomodule/redigo/redis"
@@ -13,28 +11,29 @@ const (
prefixMetaData = "fmeta:"
)
// dbToMetaData decodes a gob-encoded models.File read from the database.
func dbToMetaData(input []byte) models.File {
	var file models.File
	err := gob.NewDecoder(bytes.NewReader(input)).Decode(&file)
	helper.Check(err)
	return file
}
// GetAllMetadata returns a map of all available files
func (p DatabaseProvider) GetAllMetadata() map[string]models.File {
result := make(map[string]models.File)
allMetaData := p.getAllValuesWithPrefix(prefixMetaData)
for _, metaData := range allMetaData {
content, err := redigo.Bytes(metaData, nil)
maps := p.getAllHashesWithPrefix(prefixMetaData)
for k, v := range maps {
file, err := newDbToMetadata(k, v)
helper.Check(err)
file := dbToMetaData(content)
result[file.Id] = file
}
return result
}
// newDbToMetadata converts a Redis hashmap (key plus HGETALL reply values)
// into a models.File. The file ID is derived from the Redis key by stripping
// the metadata prefix; the encryption info is restored from its serialised form.
func newDbToMetadata(id string, input []any) (models.File, error) {
	var result models.File
	err := redigo.ScanStruct(input, &result)
	if err != nil {
		return models.File{}, err
	}
	// TrimPrefix is the idiomatic way to strip a leading key prefix
	// (the previous strings.Replace(…, 1) behaved identically here, since the
	// prefix always sits at the start of the key)
	result.Id = strings.TrimPrefix(id, prefixMetaData)
	err = result.RedisToFile()
	return result, err
}
// GetAllMetaDataIds returns all Ids that contain metadata
func (p DatabaseProvider) GetAllMetaDataIds() []string {
result := make([]string, 0)
@@ -46,23 +45,31 @@ func (p DatabaseProvider) GetAllMetaDataIds() []string {
// GetMetaDataById returns a models.File from the ID passed or false if the id is not valid.
// NOTE(review): this span previously contained interleaved old and new diff
// lines (both getKeyBytes and getHashMap reads); resolved to the hashmap-based
// implementation.
func (p DatabaseProvider) GetMetaDataById(id string) (models.File, bool) {
	result, ok := p.getHashMap(prefixMetaData + id)
	if !ok {
		return models.File{}, false
	}
	file, err := newDbToMetadata(id, result)
	helper.Check(err)
	return file, true
}
// SaveMetaData stores the metadata of a file to the disk.
// NOTE(review): this span previously contained interleaved old and new diff
// lines (both gob-blob encoding and the hashmap write); resolved to the
// hashmap-based implementation.
func (p DatabaseProvider) SaveMetaData(file models.File) {
	// FileToRedis serialises the nested EncryptionInfo into the flat
	// InternalRedisEncryption byte field so the struct can be stored as a
	// Redis hashmap (Encryption itself is tagged redis:"-")
	err := file.FileToRedis()
	helper.Check(err)
	p.setHashMap(p.buildArgs(prefixMetaData + file.Id).AddFlat(file))
}
// DeleteMetaData deletes information about a file.
// Removing the prefixed key drops the whole metadata hashmap for that file.
func (p DatabaseProvider) DeleteMetaData(id string) {
	p.deleteKey(prefixMetaData + id)
}
// IncreaseDownloadCount increases the download count of a file, preventing race conditions.
// Each counter is changed through a single HINCRBY command, so concurrent
// downloads cannot lose updates. When decreaseRemainingDownloads is set, the
// remaining-downloads counter is decremented as well.
func (p DatabaseProvider) IncreaseDownloadCount(id string, decreaseRemainingDownloads bool) {
	key := prefixMetaData + id
	if decreaseRemainingDownloads {
		p.decreaseHashmapIntField(key, "DownloadsRemaining")
	}
	p.increaseHashmapIntField(key, "DownloadCount")
}
@@ -207,15 +207,6 @@ func (p DatabaseProvider) createNewDatabase() error {
"ValidUntil" INTEGER NOT NULL,
PRIMARY KEY("Id")
) WITHOUT ROWID;
CREATE TABLE "UploadConfig" (
"id" INTEGER NOT NULL UNIQUE,
"Downloads" INTEGER,
"TimeExpiry" INTEGER,
"Password" TEXT,
"UnlimitedDownloads" INTEGER,
"UnlimitedTime" INTEGER,
PRIMARY KEY("id")
);
CREATE TABLE "UploadStatus" (
"ChunkId" TEXT NOT NULL UNIQUE,
"CurrentStatus" INTEGER NOT NULL,
@@ -176,6 +176,49 @@ func TestHotlink(t *testing.T) {
test.IsEqualInt(t, len(hotlinks), 3)
}
// TestDatabaseProvider_IncreaseDownloadCount verifies that IncreaseDownloadCount
// increments DownloadCount by one and only decrements DownloadsRemaining when
// requested, while leaving every other metadata field untouched.
func TestDatabaseProvider_IncreaseDownloadCount(t *testing.T) {
	// Fully-populated file so the struct comparisons below catch accidental changes
	// to unrelated fields
	newFile := models.File{
		Id:                 "newFileId",
		Name:               "newFileName",
		Size:               "3GB",
		SHA1:               "newSHA1",
		PasswordHash:       "newPassword",
		HotlinkId:          "newHotlink",
		ContentType:        "newContent",
		AwsBucket:          "newAws",
		ExpireAt:           123456,
		SizeBytes:          456789,
		DownloadsRemaining: 11,
		DownloadCount:      2,
		Encryption: models.EncryptionInfo{
			IsEncrypted:         true,
			IsEndToEndEncrypted: true,
			DecryptionKey:       []byte("newDecryptionKey"),
			Nonce:               []byte("newDecryptionNonce"),
		},
		UnlimitedDownloads: true,
		UnlimitedTime:      true,
	}
	dbInstance.SaveMetaData(newFile)
	// Increase only the download count; remaining downloads must stay unchanged
	dbInstance.IncreaseDownloadCount(newFile.Id, false)
	retrievedFile, ok := dbInstance.GetMetaDataById(newFile.Id)
	test.IsEqualBool(t, ok, true)
	test.IsEqualInt(t, retrievedFile.DownloadCount, 3)
	test.IsEqualInt(t, retrievedFile.DownloadsRemaining, 11)
	newFile.DownloadCount = 3
	test.IsEqual(t, retrievedFile, newFile)
	// Increase the download count and also decrease the remaining downloads
	dbInstance.IncreaseDownloadCount(newFile.Id, true)
	retrievedFile, ok = dbInstance.GetMetaDataById(newFile.Id)
	test.IsEqualBool(t, ok, true)
	test.IsEqualInt(t, retrievedFile.DownloadCount, 4)
	test.IsEqualInt(t, retrievedFile.DownloadsRemaining, 10)
	newFile.DownloadCount = 4
	newFile.DownloadsRemaining = 10
	test.IsEqual(t, retrievedFile, newFile)
	dbInstance.DeleteMetaData(newFile.Id)
}
func TestApiKey(t *testing.T) {
dbInstance.SaveApiKey(models.ApiKey{
Id: "newkey",
@@ -158,6 +158,18 @@ func (p DatabaseProvider) SaveMetaData(file models.File) {
helper.Check(err)
}
// IncreaseDownloadCount increases the download count of a file, preventing race conditions.
// The counters are changed in a single UPDATE statement, so concurrent downloads
// cannot lose updates. When decreaseRemainingDownloads is set, the
// remaining-downloads counter is decremented in the same statement.
func (p DatabaseProvider) IncreaseDownloadCount(id string, decreaseRemainingDownloads bool) {
	// Build the statement once instead of duplicating the Exec call; use "Id"
	// to match the column casing used by the sibling DeleteMetaData query
	query := "UPDATE FileMetaData SET DownloadCount = DownloadCount + 1 WHERE Id = ?"
	if decreaseRemainingDownloads {
		query = "UPDATE FileMetaData SET DownloadCount = DownloadCount + 1, DownloadsRemaining = DownloadsRemaining - 1 WHERE Id = ?"
	}
	_, err := p.sqliteDb.Exec(query, id)
	helper.Check(err)
}
// DeleteMetaData deletes information about a file
func (p DatabaseProvider) DeleteMetaData(id string) {
_, err := p.sqliteDb.Exec("DELETE FROM FileMetaData WHERE Id = ?", id)
+51 -20
View File
@@ -1,6 +1,8 @@
package models
import (
"bytes"
"encoding/gob"
"encoding/json"
"fmt"
"github.com/jinzhu/copier"
@@ -9,22 +11,51 @@ import (
// File is a struct used for saving information about an uploaded file.
// NOTE(review): this span previously contained both the old untagged and the
// new redis-tagged field lists (duplicate fields, invalid Go); resolved to the
// tagged version.
type File struct {
	Id                      string         `json:"Id" redis:"Id"`                                 // The internal ID of the file
	Name                    string         `json:"Name" redis:"Name"`                             // The filename. Will be 'Encrypted file' for end-to-end encrypted files
	Size                    string         `json:"Size" redis:"Size"`                             // Filesize in a human-readable format
	SHA1                    string         `json:"SHA1" redis:"SHA1"`                             // The hash of the file, used for deduplication
	PasswordHash            string         `json:"PasswordHash" redis:"PasswordHash"`             // The hash of the password (if the file is password protected)
	HotlinkId               string         `json:"HotlinkId" redis:"HotlinkId"`                   // If file is a picture file and can be hotlinked, this is the ID for the hotlink
	ContentType             string         `json:"ContentType" redis:"ContentType"`               // The MIME type for the file
	AwsBucket               string         `json:"AwsBucket" redis:"AwsBucket"`                   // If the file is stored in the cloud, this is the bucket that is being used
	ExpireAtString          string         `json:"ExpireAtString" redis:"ExpireAtString"`         // Time expiry in a human-readable format in local time
	ExpireAt                int64          `json:"ExpireAt" redis:"ExpireAt"`                     // UTC timestamp of file expiry
	SizeBytes               int64          `json:"SizeBytes" redis:"SizeBytes"`                   // Filesize in bytes
	DownloadsRemaining      int            `json:"DownloadsRemaining" redis:"DownloadsRemaining"` // The remaining downloads for this file
	DownloadCount           int            `json:"DownloadCount" redis:"DownloadCount"`           // The amount of times the file has been downloaded
	Encryption              EncryptionInfo `json:"Encryption" redis:"-"`                          // If the file is encrypted, this stores all info for decrypting
	UnlimitedDownloads      bool           `json:"UnlimitedDownloads" redis:"UnlimitedDownloads"` // True if the uploader did not limit the downloads
	UnlimitedTime           bool           `json:"UnlimitedTime" redis:"UnlimitedTime"`           // True if the uploader did not limit the time
	InternalRedisEncryption []byte         `redis:"EncryptionRedis"`                              // Internal field, used to store the EncryptionInfo in a Redis hashmap
}
// FileToRedis gob-encodes f.Encryption into the InternalRedisEncryption byte
// field so the struct can be stored in a flat Redis hashmap. It returns any
// encoding error.
func (f *File) FileToRedis() error {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(f.Encryption); err != nil {
		return err
	}
	f.InternalRedisEncryption = buf.Bytes()
	return nil
}
// RedisToFile restores f.Encryption from the gob-encoded
// InternalRedisEncryption field and clears that field afterwards. A nil
// serialised value yields a zero EncryptionInfo. It returns any decoding error.
func (f *File) RedisToFile() error {
	if f.InternalRedisEncryption == nil {
		f.Encryption = EncryptionInfo{}
		return nil
	}
	var info EncryptionInfo
	if err := gob.NewDecoder(bytes.NewReader(f.InternalRedisEncryption)).Decode(&info); err != nil {
		return err
	}
	f.Encryption = info
	f.InternalRedisEncryption = nil
	return nil
}
// FileApiOutput will be displayed for public outputs from the ID, hiding sensitive information
@@ -51,10 +82,10 @@ type FileApiOutput struct {
// EncryptionInfo holds information about the encryption used on the file.
// NOTE(review): this span previously contained both the old untagged and the
// new redis-tagged field lists (duplicate fields, invalid Go); resolved to the
// tagged version.
type EncryptionInfo struct {
	IsEncrypted         bool   `json:"IsEncrypted" redis:"IsEncrypted"`
	IsEndToEndEncrypted bool   `json:"IsEndToEndEncrypted" redis:"IsEndToEndEncrypted"`
	DecryptionKey       []byte `json:"DecryptionKey" redis:"DecryptionKey"`
	Nonce               []byte `json:"Nonce" redis:"Nonce"`
}
// IsLocalStorage returns true if the file is not stored on a remote storage
+1 -1
View File
@@ -523,7 +523,7 @@ func GetFileByHotlink(id string) (models.File, bool) {
func ServeFile(file models.File, w http.ResponseWriter, r *http.Request, forceDownload bool) {
file.DownloadsRemaining = file.DownloadsRemaining - 1
file.DownloadCount = file.DownloadCount + 1
database.SaveMetaData(file)
database.IncreaseDownloadCount(file.Id, !file.UnlimitedDownloads)
logging.AddDownload(&file, r, configuration.Get().SaveIp)
go sse.PublishDownloadCount(file)