mirror of
https://github.com/Forceu/Gokapi.git
synced 2026-05-07 06:49:23 -05:00
Change download count atomically to prevent race condition (#223)
This commit is contained in:
@@ -200,6 +200,11 @@ func DeleteMetaData(id string) {
|
||||
db.DeleteMetaData(id)
|
||||
}
|
||||
|
||||
// IncreaseDownloadCount increases the download count of a file, preventing race conditions
|
||||
func IncreaseDownloadCount(id string, decreaseRemainingDownloads bool) {
|
||||
db.IncreaseDownloadCount(id, decreaseRemainingDownloads)
|
||||
}
|
||||
|
||||
// Session Section
|
||||
|
||||
// GetSession returns the session with the given ID or false if not a valid ID
|
||||
|
||||
@@ -158,6 +158,24 @@ func TestMetaData(t *testing.T) {
|
||||
runAllTypesCompareOutput(t, func() any { return GetAllMetaDataIds() }, []string{})
|
||||
runAllTypesCompareOutput(t, func() any { return GetAllMetadata() }, map[string]models.File{})
|
||||
runAllTypesCompareTwoOutputs(t, func() (any, any) { return GetMetaDataById("testid") }, models.File{}, false)
|
||||
|
||||
increasedDownload := file
|
||||
increasedDownload.DownloadCount = increasedDownload.DownloadCount + 1
|
||||
|
||||
runAllTypesCompareTwoOutputs(t, func() (any, any) {
|
||||
SaveMetaData(file)
|
||||
IncreaseDownloadCount(file.Id, false)
|
||||
return GetMetaDataById(file.Id)
|
||||
}, increasedDownload, true)
|
||||
|
||||
increasedDownload.DownloadCount = increasedDownload.DownloadCount + 1
|
||||
increasedDownload.DownloadsRemaining = increasedDownload.DownloadsRemaining - 1
|
||||
|
||||
runAllTypesCompareTwoOutputs(t, func() (any, any) {
|
||||
IncreaseDownloadCount(file.Id, true)
|
||||
return GetMetaDataById(file.Id)
|
||||
}, increasedDownload, true)
|
||||
runAllTypesNoOutput(t, func() { DeleteMetaData(file.Id) })
|
||||
}
|
||||
|
||||
func TestUpgrade(t *testing.T) {
|
||||
|
||||
@@ -72,6 +72,8 @@ type Database interface {
|
||||
SaveMetaData(file models.File)
|
||||
// DeleteMetaData deletes information about a file
|
||||
DeleteMetaData(id string)
|
||||
// IncreaseDownloadCount increases the download count of a file, preventing race conditions
|
||||
IncreaseDownloadCount(id string, decreaseRemainingDownloads bool)
|
||||
|
||||
// GetSession returns the session with the given ID or false if not a valid ID
|
||||
GetSession(id string) (models.Session, bool)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/forceu/gokapi/internal/helper"
|
||||
@@ -17,7 +19,7 @@ type DatabaseProvider struct {
|
||||
dbPrefix string
|
||||
}
|
||||
|
||||
const DatabaseSchemeVersion = 2
|
||||
const DatabaseSchemeVersion = 3
|
||||
|
||||
// New returns an instance
|
||||
func New(dbConfig models.DbConnection) (DatabaseProvider, error) {
|
||||
@@ -91,8 +93,35 @@ func newPool(config models.DbConnection) *redigo.Pool {
|
||||
|
||||
// Upgrade migrates the DB to a new Gokapi version, if required
|
||||
func (p DatabaseProvider) Upgrade(currentDbVersion int) {
|
||||
// Currently no upgrade necessary
|
||||
return
|
||||
// < 1.9.6
|
||||
if currentDbVersion < 3 {
|
||||
allFiles := getAllLegacyMetaDataAndDelete(p)
|
||||
for _, file := range allFiles {
|
||||
p.SaveMetaData(file)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getAllLegacyMetaDataAndDelete(p DatabaseProvider) map[string]models.File {
|
||||
result := make(map[string]models.File)
|
||||
allMetaData := p.getAllValuesWithPrefix(prefixMetaData)
|
||||
for _, metaData := range allMetaData {
|
||||
content, err := redigo.Bytes(metaData, nil)
|
||||
helper.Check(err)
|
||||
file := legacyDbToMetaData(content)
|
||||
result[file.Id] = file
|
||||
p.deleteKey(prefixMetaData + file.Id)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func legacyDbToMetaData(input []byte) models.File {
|
||||
var result models.File
|
||||
buf := bytes.NewBuffer(input)
|
||||
dec := gob.NewDecoder(buf)
|
||||
err := dec.Decode(&result)
|
||||
helper.Check(err)
|
||||
return result
|
||||
}
|
||||
|
||||
const keyDbVersion = "dbversion"
|
||||
@@ -260,6 +289,20 @@ func (p DatabaseProvider) deleteKey(id string) {
|
||||
helper.Check(err)
|
||||
}
|
||||
|
||||
func (p DatabaseProvider) increaseHashmapIntField(id string, field string) {
|
||||
conn := p.pool.Get()
|
||||
defer conn.Close()
|
||||
_, err := conn.Do("HINCRBY", p.dbPrefix+id, field, 1)
|
||||
helper.Check(err)
|
||||
}
|
||||
|
||||
func (p DatabaseProvider) decreaseHashmapIntField(id string, field string) {
|
||||
conn := p.pool.Get()
|
||||
defer conn.Close()
|
||||
_, err := conn.Do("HINCRBY", p.dbPrefix+id, field, -1)
|
||||
helper.Check(err)
|
||||
}
|
||||
|
||||
func (p DatabaseProvider) runEval(cmd string) {
|
||||
conn := p.pool.Get()
|
||||
defer conn.Close()
|
||||
|
||||
@@ -241,6 +241,49 @@ func TestApiKeys(t *testing.T) {
|
||||
test.IsEqualBool(t, key.LastUsed == 10, true)
|
||||
}
|
||||
|
||||
func TestDatabaseProvider_IncreaseDownloadCount(t *testing.T) {
|
||||
newFile := models.File{
|
||||
Id: "newFileId",
|
||||
Name: "newFileName",
|
||||
Size: "3GB",
|
||||
SHA1: "newSHA1",
|
||||
PasswordHash: "newPassword",
|
||||
HotlinkId: "newHotlink",
|
||||
ContentType: "newContent",
|
||||
AwsBucket: "newAws",
|
||||
ExpireAt: 123456,
|
||||
SizeBytes: 456789,
|
||||
DownloadsRemaining: 11,
|
||||
DownloadCount: 2,
|
||||
Encryption: models.EncryptionInfo{
|
||||
IsEncrypted: true,
|
||||
IsEndToEndEncrypted: true,
|
||||
DecryptionKey: []byte("newDecryptionKey"),
|
||||
Nonce: []byte("newDecryptionNonce"),
|
||||
},
|
||||
UnlimitedDownloads: true,
|
||||
UnlimitedTime: true,
|
||||
}
|
||||
dbInstance.SaveMetaData(newFile)
|
||||
dbInstance.IncreaseDownloadCount(newFile.Id, false)
|
||||
retrievedFile, ok := dbInstance.GetMetaDataById(newFile.Id)
|
||||
test.IsEqualBool(t, ok, true)
|
||||
test.IsEqualInt(t, retrievedFile.DownloadCount, 3)
|
||||
test.IsEqualInt(t, retrievedFile.DownloadsRemaining, 11)
|
||||
newFile.DownloadCount = 3
|
||||
test.IsEqual(t, retrievedFile, newFile)
|
||||
|
||||
dbInstance.IncreaseDownloadCount(newFile.Id, true)
|
||||
retrievedFile, ok = dbInstance.GetMetaDataById(newFile.Id)
|
||||
test.IsEqualBool(t, ok, true)
|
||||
test.IsEqualInt(t, retrievedFile.DownloadCount, 4)
|
||||
test.IsEqualInt(t, retrievedFile.DownloadsRemaining, 10)
|
||||
newFile.DownloadCount = 4
|
||||
newFile.DownloadsRemaining = 10
|
||||
test.IsEqual(t, retrievedFile, newFile)
|
||||
dbInstance.DeleteMetaData(newFile.Id)
|
||||
}
|
||||
|
||||
func TestE2EConfig(t *testing.T) {
|
||||
e2econfig := models.E2EInfoEncrypted{
|
||||
Version: 1,
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"github.com/forceu/gokapi/internal/helper"
|
||||
"github.com/forceu/gokapi/internal/models"
|
||||
redigo "github.com/gomodule/redigo/redis"
|
||||
@@ -13,28 +11,29 @@ const (
|
||||
prefixMetaData = "fmeta:"
|
||||
)
|
||||
|
||||
func dbToMetaData(input []byte) models.File {
|
||||
var result models.File
|
||||
buf := bytes.NewBuffer(input)
|
||||
dec := gob.NewDecoder(buf)
|
||||
err := dec.Decode(&result)
|
||||
helper.Check(err)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetAllMetadata returns a map of all available files
|
||||
func (p DatabaseProvider) GetAllMetadata() map[string]models.File {
|
||||
result := make(map[string]models.File)
|
||||
allMetaData := p.getAllValuesWithPrefix(prefixMetaData)
|
||||
for _, metaData := range allMetaData {
|
||||
content, err := redigo.Bytes(metaData, nil)
|
||||
maps := p.getAllHashesWithPrefix(prefixMetaData)
|
||||
for k, v := range maps {
|
||||
file, err := newDbToMetadata(k, v)
|
||||
helper.Check(err)
|
||||
file := dbToMetaData(content)
|
||||
result[file.Id] = file
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func newDbToMetadata(id string, input []any) (models.File, error) {
|
||||
var result models.File
|
||||
err := redigo.ScanStruct(input, &result)
|
||||
if err != nil {
|
||||
return models.File{}, err
|
||||
}
|
||||
result.Id = strings.Replace(id, prefixMetaData, "", 1)
|
||||
err = result.RedisToFile()
|
||||
return result, err
|
||||
}
|
||||
|
||||
// GetAllMetaDataIds returns all Ids that contain metadata
|
||||
func (p DatabaseProvider) GetAllMetaDataIds() []string {
|
||||
result := make([]string, 0)
|
||||
@@ -46,23 +45,31 @@ func (p DatabaseProvider) GetAllMetaDataIds() []string {
|
||||
|
||||
// GetMetaDataById returns a models.File from the ID passed or false if the id is not valid
|
||||
func (p DatabaseProvider) GetMetaDataById(id string) (models.File, bool) {
|
||||
input, ok := p.getKeyBytes(prefixMetaData + id)
|
||||
result, ok := p.getHashMap(prefixMetaData + id)
|
||||
if !ok {
|
||||
return models.File{}, false
|
||||
}
|
||||
return dbToMetaData(input), true
|
||||
file, err := newDbToMetadata(id, result)
|
||||
helper.Check(err)
|
||||
return file, true
|
||||
}
|
||||
|
||||
// SaveMetaData stores the metadata of a file to the disk
|
||||
func (p DatabaseProvider) SaveMetaData(file models.File) {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
err := enc.Encode(file)
|
||||
err := file.FileToRedis()
|
||||
helper.Check(err)
|
||||
p.setKey(prefixMetaData+file.Id, buf.Bytes())
|
||||
p.setHashMap(p.buildArgs(prefixMetaData + file.Id).AddFlat(file))
|
||||
}
|
||||
|
||||
// DeleteMetaData deletes information about a file
|
||||
func (p DatabaseProvider) DeleteMetaData(id string) {
|
||||
p.deleteKey(prefixMetaData + id)
|
||||
}
|
||||
|
||||
// IncreaseDownloadCount increases the download count of a file, preventing race conditions
|
||||
func (p DatabaseProvider) IncreaseDownloadCount(id string, decreaseRemainingDownloads bool) {
|
||||
if decreaseRemainingDownloads {
|
||||
p.decreaseHashmapIntField(prefixMetaData+id, "DownloadsRemaining")
|
||||
}
|
||||
p.increaseHashmapIntField(prefixMetaData+id, "DownloadCount")
|
||||
}
|
||||
|
||||
@@ -207,15 +207,6 @@ func (p DatabaseProvider) createNewDatabase() error {
|
||||
"ValidUntil" INTEGER NOT NULL,
|
||||
PRIMARY KEY("Id")
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE "UploadConfig" (
|
||||
"id" INTEGER NOT NULL UNIQUE,
|
||||
"Downloads" INTEGER,
|
||||
"TimeExpiry" INTEGER,
|
||||
"Password" TEXT,
|
||||
"UnlimitedDownloads" INTEGER,
|
||||
"UnlimitedTime" INTEGER,
|
||||
PRIMARY KEY("id")
|
||||
);
|
||||
CREATE TABLE "UploadStatus" (
|
||||
"ChunkId" TEXT NOT NULL UNIQUE,
|
||||
"CurrentStatus" INTEGER NOT NULL,
|
||||
|
||||
@@ -176,6 +176,49 @@ func TestHotlink(t *testing.T) {
|
||||
test.IsEqualInt(t, len(hotlinks), 3)
|
||||
}
|
||||
|
||||
func TestDatabaseProvider_IncreaseDownloadCount(t *testing.T) {
|
||||
newFile := models.File{
|
||||
Id: "newFileId",
|
||||
Name: "newFileName",
|
||||
Size: "3GB",
|
||||
SHA1: "newSHA1",
|
||||
PasswordHash: "newPassword",
|
||||
HotlinkId: "newHotlink",
|
||||
ContentType: "newContent",
|
||||
AwsBucket: "newAws",
|
||||
ExpireAt: 123456,
|
||||
SizeBytes: 456789,
|
||||
DownloadsRemaining: 11,
|
||||
DownloadCount: 2,
|
||||
Encryption: models.EncryptionInfo{
|
||||
IsEncrypted: true,
|
||||
IsEndToEndEncrypted: true,
|
||||
DecryptionKey: []byte("newDecryptionKey"),
|
||||
Nonce: []byte("newDecryptionNonce"),
|
||||
},
|
||||
UnlimitedDownloads: true,
|
||||
UnlimitedTime: true,
|
||||
}
|
||||
dbInstance.SaveMetaData(newFile)
|
||||
dbInstance.IncreaseDownloadCount(newFile.Id, false)
|
||||
retrievedFile, ok := dbInstance.GetMetaDataById(newFile.Id)
|
||||
test.IsEqualBool(t, ok, true)
|
||||
test.IsEqualInt(t, retrievedFile.DownloadCount, 3)
|
||||
test.IsEqualInt(t, retrievedFile.DownloadsRemaining, 11)
|
||||
newFile.DownloadCount = 3
|
||||
test.IsEqual(t, retrievedFile, newFile)
|
||||
|
||||
dbInstance.IncreaseDownloadCount(newFile.Id, true)
|
||||
retrievedFile, ok = dbInstance.GetMetaDataById(newFile.Id)
|
||||
test.IsEqualBool(t, ok, true)
|
||||
test.IsEqualInt(t, retrievedFile.DownloadCount, 4)
|
||||
test.IsEqualInt(t, retrievedFile.DownloadsRemaining, 10)
|
||||
newFile.DownloadCount = 4
|
||||
newFile.DownloadsRemaining = 10
|
||||
test.IsEqual(t, retrievedFile, newFile)
|
||||
dbInstance.DeleteMetaData(newFile.Id)
|
||||
}
|
||||
|
||||
func TestApiKey(t *testing.T) {
|
||||
dbInstance.SaveApiKey(models.ApiKey{
|
||||
Id: "newkey",
|
||||
|
||||
@@ -158,6 +158,18 @@ func (p DatabaseProvider) SaveMetaData(file models.File) {
|
||||
helper.Check(err)
|
||||
}
|
||||
|
||||
// IncreaseDownloadCount increases the download count of a file, preventing race conditions
|
||||
func (p DatabaseProvider) IncreaseDownloadCount(id string, decreaseRemainingDownloads bool) {
|
||||
if decreaseRemainingDownloads {
|
||||
_, err := p.sqliteDb.Exec(`UPDATE FileMetaData SET DownloadCount = DownloadCount + 1,
|
||||
DownloadsRemaining = DownloadsRemaining - 1 WHERE id = ?`, id)
|
||||
helper.Check(err)
|
||||
} else {
|
||||
_, err := p.sqliteDb.Exec(`UPDATE FileMetaData SET DownloadCount = DownloadCount + 1 WHERE id = ?`, id)
|
||||
helper.Check(err)
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteMetaData deletes information about a file
|
||||
func (p DatabaseProvider) DeleteMetaData(id string) {
|
||||
_, err := p.sqliteDb.Exec("DELETE FROM FileMetaData WHERE Id = ?", id)
|
||||
|
||||
+51
-20
@@ -1,6 +1,8 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/jinzhu/copier"
|
||||
@@ -9,22 +11,51 @@ import (
|
||||
|
||||
// File is a struct used for saving information about an uploaded file
|
||||
type File struct {
|
||||
Id string `json:"Id"` // The internal ID of the file
|
||||
Name string `json:"Name"` // The filename. Will be 'Encrypted file' for end-to-end encrypted files
|
||||
Size string `json:"Size"` // Filesize in a human-readable format
|
||||
SHA1 string `json:"SHA1"` // The hash of the file, used for deduplication
|
||||
PasswordHash string `json:"PasswordHash"` // The hash of the password (if the file is password protected)
|
||||
HotlinkId string `json:"HotlinkId"` // If file is a picture file and can be hotlinked, this is the ID for the hotlink
|
||||
ContentType string `json:"ContentType"` // The MIME type for the file
|
||||
AwsBucket string `json:"AwsBucket"` // If the file is stored in the cloud, this is the bucket that is being used
|
||||
ExpireAtString string `json:"ExpireAtString"` // Time expiry in a human-readable format in local time
|
||||
ExpireAt int64 `json:"ExpireAt"` // "UTC timestamp of file expiry
|
||||
SizeBytes int64 `json:"SizeBytes"` // Filesize in bytes
|
||||
DownloadsRemaining int `json:"DownloadsRemaining"` // The remaining downloads for this file
|
||||
DownloadCount int `json:"DownloadCount"` // The amount of times the file has been downloaded
|
||||
Encryption EncryptionInfo `json:"Encryption"` // If the file is encrypted, this stores all info for decrypting
|
||||
UnlimitedDownloads bool `json:"UnlimitedDownloads"` // True if the uploader did not limit the downloads
|
||||
UnlimitedTime bool `json:"UnlimitedTime"` // True if the uploader did not limit the time
|
||||
Id string `json:"Id" redis:"Id"` // The internal ID of the file
|
||||
Name string `json:"Name" redis:"Name"` // The filename. Will be 'Encrypted file' for end-to-end encrypted files
|
||||
Size string `json:"Size" redis:"Size"` // Filesize in a human-readable format
|
||||
SHA1 string `json:"SHA1" redis:"SHA1"` // The hash of the file, used for deduplication
|
||||
PasswordHash string `json:"PasswordHash" redis:"PasswordHash"` // The hash of the password (if the file is password protected)
|
||||
HotlinkId string `json:"HotlinkId" redis:"HotlinkId"` // If file is a picture file and can be hotlinked, this is the ID for the hotlink
|
||||
ContentType string `json:"ContentType" redis:"ContentType"` // The MIME type for the file
|
||||
AwsBucket string `json:"AwsBucket" redis:"AwsBucket"` // If the file is stored in the cloud, this is the bucket that is being used
|
||||
ExpireAtString string `json:"ExpireAtString" redis:"ExpireAtString"` // Time expiry in a human-readable format in local time
|
||||
ExpireAt int64 `json:"ExpireAt" redis:"ExpireAt"` // "UTC timestamp of file expiry
|
||||
SizeBytes int64 `json:"SizeBytes" redis:"SizeBytes"` // Filesize in bytes
|
||||
DownloadsRemaining int `json:"DownloadsRemaining" redis:"DownloadsRemaining"` // The remaining downloads for this file
|
||||
DownloadCount int `json:"DownloadCount" redis:"DownloadCount"` // The amount of times the file has been downloaded
|
||||
Encryption EncryptionInfo `json:"Encryption" redis:"-"` // If the file is encrypted, this stores all info for decrypting
|
||||
UnlimitedDownloads bool `json:"UnlimitedDownloads" redis:"UnlimitedDownloads"` // True if the uploader did not limit the downloads
|
||||
UnlimitedTime bool `json:"UnlimitedTime" redis:"UnlimitedTime"` // True if the uploader did not limit the time
|
||||
InternalRedisEncryption []byte `redis:"EncryptionRedis"` // This field is an internal field, used to store the EncryptionInfo in a Redis Hashmap
|
||||
}
|
||||
|
||||
func (f *File) FileToRedis() error {
|
||||
var encInfo bytes.Buffer
|
||||
enc := gob.NewEncoder(&encInfo)
|
||||
err := enc.Encode(f.Encryption)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.InternalRedisEncryption = encInfo.Bytes()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *File) RedisToFile() error {
|
||||
if f.InternalRedisEncryption == nil {
|
||||
f.Encryption = EncryptionInfo{}
|
||||
return nil
|
||||
}
|
||||
var result EncryptionInfo
|
||||
buf := bytes.NewBuffer(f.InternalRedisEncryption)
|
||||
dec := gob.NewDecoder(buf)
|
||||
err := dec.Decode(&result)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Encryption = result
|
||||
f.InternalRedisEncryption = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// FileApiOutput will be displayed for public outputs from the ID, hiding sensitive information
|
||||
@@ -51,10 +82,10 @@ type FileApiOutput struct {
|
||||
|
||||
// EncryptionInfo holds information about the encryption used on the file
|
||||
type EncryptionInfo struct {
|
||||
IsEncrypted bool `json:"IsEncrypted"`
|
||||
IsEndToEndEncrypted bool `json:"IsEndToEndEncrypted"`
|
||||
DecryptionKey []byte `json:"DecryptionKey"`
|
||||
Nonce []byte `json:"Nonce"`
|
||||
IsEncrypted bool `json:"IsEncrypted" redis:"IsEncrypted"`
|
||||
IsEndToEndEncrypted bool `json:"IsEndToEndEncrypted" redis:"IsEndToEndEncrypted"`
|
||||
DecryptionKey []byte `json:"DecryptionKey" redis:"DecryptionKey"`
|
||||
Nonce []byte `json:"Nonce" redis:"Nonce"`
|
||||
}
|
||||
|
||||
// IsLocalStorage returns true if the file is not stored on a remote storage
|
||||
|
||||
@@ -523,7 +523,7 @@ func GetFileByHotlink(id string) (models.File, bool) {
|
||||
func ServeFile(file models.File, w http.ResponseWriter, r *http.Request, forceDownload bool) {
|
||||
file.DownloadsRemaining = file.DownloadsRemaining - 1
|
||||
file.DownloadCount = file.DownloadCount + 1
|
||||
database.SaveMetaData(file)
|
||||
database.IncreaseDownloadCount(file.Id, !file.UnlimitedDownloads)
|
||||
logging.AddDownload(&file, r, configuration.Get().SaveIp)
|
||||
go sse.PublishDownloadCount(file)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user