Refactoring, using mutex to prevent race conditions for all configuration reads/writes

This commit is contained in:
Marc Ole Bulling
2021-04-30 15:12:34 +02:00
parent e75b35f921
commit 3ebcafbb81
12 changed files with 291 additions and 181 deletions

2
go.mod
View File

@@ -3,7 +3,7 @@ module Gokapi
go 1.16
require (
github.com/otiai10/copy v1.5.1
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d // indirect
)

11
go.sum
View File

@@ -1,10 +1,7 @@
github.com/otiai10/copy v1.5.1 h1:a/cs2E1/1V0az8K5nblbl+ymEa4E11AfaOLMar8V34w=
github.com/otiai10/copy v1.5.1/go.mod h1:XWfuS3CrI0R6IE0FbgHsEazaXO8G0LpMp9o8tos0x4E=
github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE=
github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs=
github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT91xUo=
github.com/otiai10/mint v1.3.2 h1:VYWnrP5fXmz1MXvjuUvcBrXSjGE6xjON+axB/UrpO3E=
github.com/otiai10/mint v1.3.2/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc=
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ=
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o=
github.com/sasha-s/go-deadlock v0.2.0 h1:lMqc+fUb7RrFS3gQLtoQsJ7/6TV/pAIFvBsqX73DK8Y=
github.com/sasha-s/go-deadlock v0.2.0/go.mod h1:StQn567HiB1fF2yJ44N9au7wOhrPS3iZqiDbRupzT10=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=

View File

@@ -30,7 +30,7 @@ const minLengthPassword = 6
var Environment environment.Environment
// ServerSettings is an object containing the server configuration
var ServerSettings Configuration
var serverSettings Configuration
// Version of the configuration structure. Used for upgrading
const currentConfigVersion = 6
@@ -71,54 +71,66 @@ func Load() {
helper.Check(err)
defer file.Close()
decoder := json.NewDecoder(file)
ServerSettings = Configuration{}
err = decoder.Decode(&ServerSettings)
serverSettings = Configuration{}
err = decoder.Decode(&serverSettings)
helper.Check(err)
updateConfig()
helper.CreateDir(ServerSettings.DataDir)
helper.CreateDir(serverSettings.DataDir)
}
// LockSessions locks sessions to prevent race conditions (blocking)
func LockSessions() {
// Lock locks configuration to prevent race conditions (blocking)
func Lock() {
mutexSessions.Lock()
}
// UnlockSessionsAndSave unlocks sessions and saves the configuration
func UnlockSessionsAndSave() {
// ReleaseAndSave unlocks and saves the configuration
func ReleaseAndSave() {
Save()
mutexSessions.Unlock()
}
// Release unlocks the configuration
func Release() {
mutexSessions.Unlock()
}
// GetServerSettings locks the settings returns a pointer to the configuration
// Release needs to be called when finished with the operation!
func GetServerSettings() *Configuration {
mutexSessions.Lock()
return &serverSettings
}
// Upgrades the ServerSettings if saved with a previous version
func updateConfig() {
// < v1.1.2
if ServerSettings.ConfigVersion < 3 {
ServerSettings.SaltAdmin = "eefwkjqweduiotbrkl##$2342brerlk2321"
ServerSettings.SaltFiles = "P1UI5sRNDwuBgOvOYhNsmucZ2pqo4KEvOoqqbpdu"
ServerSettings.LengthId = 15
ServerSettings.DataDir = Environment.DataDir
if serverSettings.ConfigVersion < 3 {
serverSettings.SaltAdmin = "eefwkjqweduiotbrkl##$2342brerlk2321"
serverSettings.SaltFiles = "P1UI5sRNDwuBgOvOYhNsmucZ2pqo4KEvOoqqbpdu"
serverSettings.LengthId = 15
serverSettings.DataDir = Environment.DataDir
}
// < v1.1.3
if ServerSettings.ConfigVersion < 4 {
ServerSettings.Hotlinks = make(map[string]models.Hotlink)
if serverSettings.ConfigVersion < 4 {
serverSettings.Hotlinks = make(map[string]models.Hotlink)
}
// < v1.1.4
if ServerSettings.ConfigVersion < 5 {
ServerSettings.LengthId = 15
ServerSettings.DownloadStatus = make(map[string]models.DownloadStatus)
for _, file := range ServerSettings.Files {
if serverSettings.ConfigVersion < 5 {
serverSettings.LengthId = 15
serverSettings.DownloadStatus = make(map[string]models.DownloadStatus)
for _, file := range serverSettings.Files {
file.ContentType = "application/octet-stream"
ServerSettings.Files[file.Id] = file
serverSettings.Files[file.Id] = file
}
}
// < v1.2.0
if ServerSettings.ConfigVersion < 6 {
ServerSettings.ApiKeys = make(map[string]models.ApiKey)
if serverSettings.ConfigVersion < 6 {
serverSettings.ApiKeys = make(map[string]models.ApiKey)
}
if ServerSettings.ConfigVersion < currentConfigVersion {
if serverSettings.ConfigVersion < currentConfigVersion {
fmt.Println("Successfully upgraded database")
ServerSettings.ConfigVersion = currentConfigVersion
serverSettings.ConfigVersion = currentConfigVersion
Save()
}
}
@@ -130,7 +142,7 @@ func generateDefaultConfig() {
if saltAdmin == "" {
saltAdmin = helper.GenerateRandomString(30)
}
ServerSettings = Configuration{
serverSettings = Configuration{
SaltAdmin: saltAdmin,
}
username := askForUsername()
@@ -148,7 +160,7 @@ func generateDefaultConfig() {
saltFiles = helper.GenerateRandomString(30)
}
ServerSettings = Configuration{
serverSettings = Configuration{
Port: bindAddress,
AdminName: username,
AdminPassword: HashPassword(password, false),
@@ -179,7 +191,7 @@ func Save() {
}
defer file.Close()
encoder := json.NewEncoder(file)
err = encoder.Encode(&ServerSettings)
err = encoder.Encode(&serverSettings)
if err != nil {
fmt.Println("Error writing configuration:", err)
os.Exit(1)
@@ -344,7 +356,7 @@ func addTrailingSlash(url string) string {
// DisplayPasswordReset shows a password prompt in the CLI and saves the new password
func DisplayPasswordReset() {
ServerSettings.AdminPassword = HashPassword(askForPassword(), false)
serverSettings.AdminPassword = HashPassword(askForPassword(), false)
Save()
}
@@ -353,12 +365,22 @@ func HashPassword(password string, useFileSalt bool) string {
if password == "" {
return ""
}
salt := ServerSettings.SaltAdmin
salt := serverSettings.SaltAdmin
if useFileSalt {
salt = ServerSettings.SaltFiles
salt = serverSettings.SaltFiles
}
bytes := []byte(password + salt)
hash := sha1.New()
hash.Write(bytes)
return hex.EncodeToString(hash.Sum(nil))
}
// GetLengthId returns the length of the file IDs to be generated
func GetLengthId() int {
return serverSettings.LengthId
}
// GetSessions returns a pointer to the session storage
func GetSessions() *map[string]models.Session {
return &serverSettings.Sessions
}

View File

@@ -2,8 +2,8 @@ package configuration
import (
"Gokapi/internal/environment"
"Gokapi/internal/test"
"Gokapi/internal/test/testconfiguration"
"Gokapi/internal/test"
"Gokapi/internal/test/testconfiguration"
"os"
"testing"
"time"
@@ -19,31 +19,31 @@ func TestMain(m *testing.M) {
func TestLoad(t *testing.T) {
Load()
test.IsEqualString(t, Environment.ConfigDir, "test")
test.IsEqualString(t, ServerSettings.Port, "127.0.0.1:53843")
test.IsEqualString(t, ServerSettings.AdminName, "test")
test.IsEqualString(t, ServerSettings.ServerUrl, "http://127.0.0.1:53843/")
test.IsEqualString(t, ServerSettings.AdminPassword, "10340aece68aa4fb14507ae45b05506026f276cf")
test.IsEqualString(t, serverSettings.Port, "127.0.0.1:53843")
test.IsEqualString(t, serverSettings.AdminName, "test")
test.IsEqualString(t, serverSettings.ServerUrl, "http://127.0.0.1:53843/")
test.IsEqualString(t, serverSettings.AdminPassword, "10340aece68aa4fb14507ae45b05506026f276cf")
test.IsEqualString(t, HashPassword("testtest", false), "10340aece68aa4fb14507ae45b05506026f276cf")
test.IsEqualInt(t, ServerSettings.LengthId, 20)
test.IsEqualInt(t, serverSettings.LengthId, 20)
}
func TestMutex(t *testing.T) {
finished := make(chan bool)
oldValue := ServerSettings.ConfigVersion
oldValue := serverSettings.ConfigVersion
go func() {
time.Sleep(100 * time.Millisecond)
LockSessions()
test.IsEqualInt(t, ServerSettings.ConfigVersion, -9)
ServerSettings.ConfigVersion = oldValue
UnlockSessionsAndSave()
test.IsEqualInt(t, ServerSettings.ConfigVersion, oldValue)
Lock()
test.IsEqualInt(t, serverSettings.ConfigVersion, -9)
serverSettings.ConfigVersion = oldValue
ReleaseAndSave()
test.IsEqualInt(t, serverSettings.ConfigVersion, oldValue)
finished <- true
}()
LockSessions()
ServerSettings.ConfigVersion = -9
Lock()
serverSettings.ConfigVersion = -9
time.Sleep(150 * time.Millisecond)
test.IsEqualInt(t, ServerSettings.ConfigVersion, -9)
UnlockSessionsAndSave()
test.IsEqualInt(t, serverSettings.ConfigVersion, -9)
ReleaseAndSave()
<-finished
}
@@ -58,18 +58,18 @@ func TestCreateNewConfig(t *testing.T) {
os.Setenv("GOKAPI_LOCALHOST", "false")
Load()
test.IsEqualString(t, Environment.ConfigDir, "test")
test.IsEqualString(t, ServerSettings.Port, ":1234")
test.IsEqualString(t, ServerSettings.AdminName, "test2")
test.IsEqualString(t, ServerSettings.ServerUrl, "http://test.com/")
test.IsEqualString(t, ServerSettings.RedirectUrl, "http://test2.com")
test.IsEqualString(t, ServerSettings.AdminPassword, "5bbf5684437a4c658d2e0890d784694afb63f715")
test.IsEqualString(t, serverSettings.Port, ":1234")
test.IsEqualString(t, serverSettings.AdminName, "test2")
test.IsEqualString(t, serverSettings.ServerUrl, "http://test.com/")
test.IsEqualString(t, serverSettings.RedirectUrl, "http://test2.com")
test.IsEqualString(t, serverSettings.AdminPassword, "5bbf5684437a4c658d2e0890d784694afb63f715")
test.IsEqualString(t, HashPassword("testtest2", false), "5bbf5684437a4c658d2e0890d784694afb63f715")
test.IsEqualInt(t, ServerSettings.LengthId, 15)
test.IsEqualInt(t, serverSettings.LengthId, 15)
os.Remove("test/config.json")
os.Unsetenv("GOKAPI_SALT_ADMIN")
Load()
test.IsEqualInt(t, len(ServerSettings.SaltAdmin), 30)
test.IsNotEqualString(t, ServerSettings.SaltAdmin, "eefwkjqweduiotbrkl##$2342brerlk2321")
test.IsEqualInt(t, len(serverSettings.SaltAdmin), 30)
test.IsNotEqualString(t, serverSettings.SaltAdmin, "eefwkjqweduiotbrkl##$2342brerlk2321")
os.Unsetenv("GOKAPI_USERNAME")
os.Unsetenv("GOKAPI_PASSWORD")
os.Unsetenv("GOKAPI_PORT")
@@ -81,14 +81,14 @@ func TestCreateNewConfig(t *testing.T) {
func TestUpgradeDb(t *testing.T) {
testconfiguration.WriteUpgradeConfigFile()
Load()
test.IsEqualString(t, ServerSettings.SaltAdmin, "eefwkjqweduiotbrkl##$2342brerlk2321")
test.IsEqualString(t, ServerSettings.SaltFiles, "P1UI5sRNDwuBgOvOYhNsmucZ2pqo4KEvOoqqbpdu")
test.IsEqualString(t, ServerSettings.DataDir, Environment.DataDir)
test.IsEqualInt(t, ServerSettings.LengthId, 15)
test.IsEqualBool(t, ServerSettings.Hotlinks == nil, false)
test.IsEqualBool(t, ServerSettings.DownloadStatus == nil, false)
test.IsEqualString(t, ServerSettings.Files["MgXJLe4XLfpXcL12ec4i"].ContentType, "application/octet-stream")
test.IsEqualInt(t, ServerSettings.ConfigVersion, currentConfigVersion)
test.IsEqualString(t, serverSettings.SaltAdmin, "eefwkjqweduiotbrkl##$2342brerlk2321")
test.IsEqualString(t, serverSettings.SaltFiles, "P1UI5sRNDwuBgOvOYhNsmucZ2pqo4KEvOoqqbpdu")
test.IsEqualString(t, serverSettings.DataDir, Environment.DataDir)
test.IsEqualInt(t, serverSettings.LengthId, 15)
test.IsEqualBool(t, serverSettings.Hotlinks == nil, false)
test.IsEqualBool(t, serverSettings.DownloadStatus == nil, false)
test.IsEqualString(t, serverSettings.Files["MgXJLe4XLfpXcL12ec4i"].ContentType, "application/octet-stream")
test.IsEqualInt(t, serverSettings.ConfigVersion, currentConfigVersion)
testconfiguration.Create(false)
Load()
}

View File

@@ -10,23 +10,29 @@ import (
// SetDownload creates a new DownloadStatus struct and returns its Id
func SetDownload(file models.File) string {
status := newDownloadStatus(file)
configuration.ServerSettings.DownloadStatus[status.Id] = status
settings := configuration.GetServerSettings()
settings.DownloadStatus[status.Id] = status
configuration.ReleaseAndSave()
return status.Id
}
// SetComplete removes the download object
func SetComplete(id string) {
delete(configuration.ServerSettings.DownloadStatus, id)
settings := configuration.GetServerSettings()
delete(settings.DownloadStatus, id)
configuration.ReleaseAndSave()
}
// Clean removes all expires status objects
func Clean() {
settings := configuration.GetServerSettings()
now := time.Now().Unix()
for _, item := range configuration.ServerSettings.DownloadStatus {
for _, item := range settings.DownloadStatus {
if item.ExpireAt < now {
delete(configuration.ServerSettings.DownloadStatus, item.Id)
delete(settings.DownloadStatus, item.Id)
}
}
configuration.Release()
}
// newDownloadStatus initialises the a new DownloadStatus item
@@ -40,8 +46,8 @@ func newDownloadStatus(file models.File) models.DownloadStatus {
}
// IsCurrentlyDownloading returns true if file is currently being downloaded
func IsCurrentlyDownloading(file models.File) bool {
for _, status := range configuration.ServerSettings.DownloadStatus {
func IsCurrentlyDownloading(file models.File, settings *configuration.Configuration) bool {
for _, status := range settings.DownloadStatus {
if status.FileId == file.Id {
if status.ExpireAt > time.Now().Unix() {
return true

View File

@@ -4,6 +4,7 @@ import (
"Gokapi/internal/configuration"
"Gokapi/internal/models"
"Gokapi/internal/test"
"Gokapi/internal/test/testconfiguration"
"os"
"testing"
"time"
@@ -13,7 +14,10 @@ var testFile models.File
var statusId string
func TestMain(m *testing.M) {
configuration.ServerSettings.DownloadStatus = make(map[string]models.DownloadStatus)
testconfiguration.Create(false)
configuration.Load()
settings := configuration.GetServerSettings()
settings.DownloadStatus = make(map[string]models.DownloadStatus)
testFile = models.File{
Id: "test",
Name: "testName",
@@ -23,7 +27,9 @@ func TestMain(m *testing.M) {
ExpireAtString: "expire",
DownloadsRemaining: 1,
}
configuration.Release()
exitVal := m.Run()
testconfiguration.Delete()
os.Exit(exitVal)
}
@@ -36,7 +42,9 @@ func TestNewDownloadStatus(t *testing.T) {
func TestSetDownload(t *testing.T) {
statusId = SetDownload(testFile)
status := configuration.ServerSettings.DownloadStatus[statusId]
settings := configuration.GetServerSettings()
status := settings.DownloadStatus[statusId]
configuration.Release()
test.IsNotEmpty(t, status.Id)
test.IsEqualString(t, status.Id, statusId)
test.IsEqualString(t, status.FileId, testFile.Id)
@@ -44,27 +52,33 @@ func TestSetDownload(t *testing.T) {
}
func TestSetComplete(t *testing.T) {
status := configuration.ServerSettings.DownloadStatus[statusId]
settings := configuration.GetServerSettings()
status := settings.DownloadStatus[statusId]
configuration.Release()
test.IsNotEmpty(t, status.Id)
SetComplete(statusId)
status = configuration.ServerSettings.DownloadStatus[statusId]
status = settings.DownloadStatus[statusId]
test.IsEmpty(t, status.Id)
}
func TestIsCurrentlyDownloading(t *testing.T) {
statusId = SetDownload(testFile)
test.IsEqualBool(t, IsCurrentlyDownloading(testFile), true)
test.IsEqualBool(t, IsCurrentlyDownloading(models.File{Id: "notDownloading"}), false)
settings := configuration.GetServerSettings()
configuration.Release()
test.IsEqualBool(t, IsCurrentlyDownloading(testFile, settings), true)
test.IsEqualBool(t, IsCurrentlyDownloading(models.File{Id: "notDownloading"}, settings), false)
}
func TestClean(t *testing.T) {
test.IsEqualInt(t, len(configuration.ServerSettings.DownloadStatus), 1)
settings := configuration.GetServerSettings()
configuration.Release()
test.IsEqualInt(t, len(settings.DownloadStatus), 1)
Clean()
test.IsEqualInt(t, len(configuration.ServerSettings.DownloadStatus), 1)
status := configuration.ServerSettings.DownloadStatus[statusId]
test.IsEqualInt(t, len(settings.DownloadStatus), 1)
status := settings.DownloadStatus[statusId]
status.ExpireAt = 1
configuration.ServerSettings.DownloadStatus[statusId] = status
test.IsEqualInt(t, len(configuration.ServerSettings.DownloadStatus), 1)
settings.DownloadStatus[statusId] = status
test.IsEqualInt(t, len(settings.DownloadStatus), 1)
Clean()
test.IsEqualInt(t, len(configuration.ServerSettings.DownloadStatus), 0)
test.IsEqualInt(t, len(settings.DownloadStatus), 0)
}

View File

@@ -1,11 +1,13 @@
package models
// ApiKey contains data of a single api key
type ApiKey struct {
Id string `json:"Id"`
FriendlyName string `json:"FriendlyName"`
LastUsed int64 `json:"LastUsed"`
}
// UploadItem is the result for the "list uploads" api call
type UploadItem struct {
Id string `json:"Id"`
Name string `json:"Name"`

View File

@@ -31,7 +31,7 @@ func NewFile(fileContent io.Reader, fileHeader *multipart.FileHeader, expireAt i
if err != nil {
return models.File{}, err
}
id := helper.GenerateRandomString(configuration.ServerSettings.LengthId)
id := helper.GenerateRandomString(configuration.GetLengthId())
hash := sha1.New()
hash.Write(fileBytes)
file := models.File{
@@ -46,9 +46,11 @@ func NewFile(fileContent io.Reader, fileHeader *multipart.FileHeader, expireAt i
ContentType: fileHeader.Header.Get("Content-Type"),
}
addHotlink(&file)
configuration.ServerSettings.Files[id] = file
filename := configuration.ServerSettings.DataDir + "/" + file.SHA256
if !helper.FileExists(configuration.ServerSettings.DataDir + "/" + file.SHA256) {
settings := configuration.GetServerSettings()
defer func() { configuration.ReleaseAndSave() }()
settings.Files[id] = file
filename := settings.DataDir + "/" + file.SHA256
if !helper.FileExists(settings.DataDir + "/" + file.SHA256) {
destinationFile, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return models.File{}, err
@@ -56,7 +58,6 @@ func NewFile(fileContent io.Reader, fileHeader *multipart.FileHeader, expireAt i
defer destinationFile.Close()
destinationFile.Write(fileBytes)
}
configuration.Save()
return file, nil
}
@@ -70,10 +71,12 @@ func addHotlink(file *models.File) {
}
link := helper.GenerateRandomString(40) + extension
file.HotlinkId = link
configuration.ServerSettings.Hotlinks[link] = models.Hotlink{
settings := configuration.GetServerSettings()
settings.Hotlinks[link] = models.Hotlink{
Id: link,
FileId: file.Id,
}
configuration.Release()
}
// GetFile gets the file by id. Returns (empty File, false) if invalid / expired file
@@ -83,11 +86,13 @@ func GetFile(id string) (models.File, bool) {
if id == "" {
return emptyResult, false
}
file := configuration.ServerSettings.Files[id]
settings := configuration.GetServerSettings()
file := settings.Files[id]
configuration.Release()
if file.ExpireAt < time.Now().Unix() || file.DownloadsRemaining < 1 {
return emptyResult, false
}
if !helper.FileExists(configuration.ServerSettings.DataDir + "/" + file.SHA256) {
if !helper.FileExists(settings.DataDir + "/" + file.SHA256) {
return emptyResult, false
}
return file, true
@@ -100,15 +105,19 @@ func GetFileByHotlink(id string) (models.File, bool) {
if id == "" {
return emptyResult, false
}
hotlink := configuration.ServerSettings.Hotlinks[id]
settings := configuration.GetServerSettings()
hotlink := settings.Hotlinks[id]
configuration.Release()
return GetFile(hotlink.FileId)
}
// ServeFile subtracts a download allowance and serves the file to the browser
func ServeFile(file models.File, w http.ResponseWriter, r *http.Request, forceDownload bool) {
file.DownloadsRemaining = file.DownloadsRemaining - 1
configuration.ServerSettings.Files[file.Id] = file
storageData, err := os.OpenFile(configuration.ServerSettings.DataDir+"/"+file.SHA256, os.O_RDONLY, 0644)
settings := configuration.GetServerSettings()
settings.Files[file.Id] = file
storageData, err := os.OpenFile(settings.DataDir+"/"+file.SHA256, os.O_RDONLY, 0644)
configuration.Release()
helper.Check(err)
defer storageData.Close()
size, err := helper.GetFileSize(storageData)
@@ -119,10 +128,8 @@ func ServeFile(file models.File, w http.ResponseWriter, r *http.Request, forceDo
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
w.Header().Set("Content-Type", file.ContentType)
statusId := downloadstatus.SetDownload(file)
configuration.Save()
http.ServeContent(w, r, file.Name, time.Now(), storageData)
downloadstatus.SetComplete(statusId)
configuration.Save()
}
// CleanUp removes expired files from the config and from the filesystem if they are not referenced by other files anymore
@@ -132,28 +139,30 @@ func CleanUp(periodic bool) {
downloadstatus.Clean()
timeNow := time.Now().Unix()
wasItemDeleted := false
for key, element := range configuration.ServerSettings.Files {
fileExists := helper.FileExists(configuration.ServerSettings.DataDir + "/" + element.SHA256)
if (element.ExpireAt < timeNow || element.DownloadsRemaining < 1 || !fileExists) && !downloadstatus.IsCurrentlyDownloading(element) {
settings := configuration.GetServerSettings()
for key, element := range settings.Files {
fileExists := helper.FileExists(settings.DataDir + "/" + element.SHA256)
if (element.ExpireAt < timeNow || element.DownloadsRemaining < 1 || !fileExists) && !downloadstatus.IsCurrentlyDownloading(element, settings) {
deleteFile := true
for _, secondLoopElement := range configuration.ServerSettings.Files {
for _, secondLoopElement := range settings.Files {
if element.Id != secondLoopElement.Id && element.SHA256 == secondLoopElement.SHA256 {
deleteFile = false
}
}
if deleteFile && fileExists {
err := os.Remove(configuration.ServerSettings.DataDir + "/" + element.SHA256)
err := os.Remove(settings.DataDir + "/" + element.SHA256)
if err != nil {
fmt.Println(err)
}
}
if element.HotlinkId != "" {
delete(configuration.ServerSettings.Hotlinks, element.HotlinkId)
delete(settings.Hotlinks, element.HotlinkId)
}
delete(configuration.ServerSettings.Files, key)
delete(settings.Files, key)
wasItemDeleted = true
}
}
configuration.Release()
if wasItemDeleted {
configuration.Save()
CleanUp(false)
@@ -163,3 +172,13 @@ func CleanUp(periodic bool) {
go CleanUp(periodic)
}
}
// DeleteFile is called when an admin requests deletion of a file
func DeleteFile(keyId string) {
settings := configuration.GetServerSettings()
item := settings.Files[keyId]
item.ExpireAt = 0
settings.Files[keyId] = item
configuration.Release()
CleanUp(false)
}

View File

@@ -61,8 +61,10 @@ func TestAddHotlink(t *testing.T) {
test.IsEqualInt(t, len(file.HotlinkId), 44)
lastCharacters := file.HotlinkId[len(file.HotlinkId)-4:]
test.IsEqualBool(t, lastCharacters == ".jpg", true)
test.IsEqualString(t, configuration.ServerSettings.Hotlinks[file.HotlinkId].FileId, "testId")
test.IsEqualString(t, configuration.ServerSettings.Hotlinks[file.HotlinkId].Id, file.HotlinkId)
settings := configuration.GetServerSettings()
test.IsEqualString(t, settings.Hotlinks[file.HotlinkId].FileId, "testId")
test.IsEqualString(t, settings.Hotlinks[file.HotlinkId].Id, file.HotlinkId)
configuration.Release()
}
func TestNewFile(t *testing.T) {
@@ -108,66 +110,80 @@ func TestServeFile(t *testing.T) {
}
func TestCleanUp(t *testing.T) {
test.IsEqualString(t, configuration.ServerSettings.Files["cleanuptest123456789"].Name, "cleanup")
test.IsEqualString(t, configuration.ServerSettings.Files["Wzol7LyY2QVczXynJtVo"].Name, "smallfile2")
test.IsEqualString(t, configuration.ServerSettings.Files["e4TjE7CokWK0giiLNxDL"].Name, "smallfile2")
test.IsEqualString(t, configuration.ServerSettings.Files["wefffewhtrhhtrhtrhtr"].Name, "smallfile3")
test.IsEqualString(t, configuration.ServerSettings.Files["n1tSTAGj8zan9KaT4u6p"].Name, "picture.jpg")
test.IsEqualString(t, configuration.ServerSettings.Files["deletedfile123456789"].Name, "DeletedFile")
settings := configuration.GetServerSettings()
configuration.Release()
test.IsEqualString(t, settings.Files["cleanuptest123456789"].Name, "cleanup")
test.IsEqualString(t, settings.Files["Wzol7LyY2QVczXynJtVo"].Name, "smallfile2")
test.IsEqualString(t, settings.Files["e4TjE7CokWK0giiLNxDL"].Name, "smallfile2")
test.IsEqualString(t, settings.Files["wefffewhtrhhtrhtrhtr"].Name, "smallfile3")
test.IsEqualString(t, settings.Files["n1tSTAGj8zan9KaT4u6p"].Name, "picture.jpg")
test.IsEqualString(t, settings.Files["deletedfile123456789"].Name, "DeletedFile")
test.IsEqualBool(t, helper.FileExists("test/data/2341354656543213246465465465432456898794"), true)
CleanUp(false)
test.IsEqualString(t, configuration.ServerSettings.Files["cleanuptest123456789"].Name, "cleanup")
test.IsEqualString(t, settings.Files["cleanuptest123456789"].Name, "cleanup")
test.IsEqualBool(t, helper.FileExists("test/data/2341354656543213246465465465432456898794"), true)
test.IsEqualString(t, configuration.ServerSettings.Files["deletedfile123456789"].Name, "")
test.IsEqualString(t, configuration.ServerSettings.Files["Wzol7LyY2QVczXynJtVo"].Name, "smallfile2")
test.IsEqualString(t, configuration.ServerSettings.Files["e4TjE7CokWK0giiLNxDL"].Name, "smallfile2")
test.IsEqualString(t, configuration.ServerSettings.Files["wefffewhtrhhtrhtrhtr"].Name, "smallfile3")
test.IsEqualString(t, configuration.ServerSettings.Files["n1tSTAGj8zan9KaT4u6p"].Name, "picture.jpg")
test.IsEqualString(t, settings.Files["deletedfile123456789"].Name, "")
test.IsEqualString(t, settings.Files["Wzol7LyY2QVczXynJtVo"].Name, "smallfile2")
test.IsEqualString(t, settings.Files["e4TjE7CokWK0giiLNxDL"].Name, "smallfile2")
test.IsEqualString(t, settings.Files["wefffewhtrhhtrhtrhtr"].Name, "smallfile3")
test.IsEqualString(t, settings.Files["n1tSTAGj8zan9KaT4u6p"].Name, "picture.jpg")
file, _ := GetFile("n1tSTAGj8zan9KaT4u6p")
file.DownloadsRemaining = 0
configuration.ServerSettings.Files["n1tSTAGj8zan9KaT4u6p"] = file
settings.Files["n1tSTAGj8zan9KaT4u6p"] = file
CleanUp(false)
test.IsEqualBool(t, helper.FileExists("test/data/a8fdc205a9f19cc1c7507a60c4f01b13d11d7fd0"), false)
test.IsEqualString(t, configuration.ServerSettings.Files["n1tSTAGj8zan9KaT4u6p"].Name, "")
test.IsEqualString(t, configuration.ServerSettings.Files["deletedfile123456789"].Name, "")
test.IsEqualString(t, configuration.ServerSettings.Files["Wzol7LyY2QVczXynJtVo"].Name, "smallfile2")
test.IsEqualString(t, configuration.ServerSettings.Files["e4TjE7CokWK0giiLNxDL"].Name, "smallfile2")
test.IsEqualString(t, configuration.ServerSettings.Files["wefffewhtrhhtrhtrhtr"].Name, "smallfile3")
test.IsEqualString(t, settings.Files["n1tSTAGj8zan9KaT4u6p"].Name, "")
test.IsEqualString(t, settings.Files["deletedfile123456789"].Name, "")
test.IsEqualString(t, settings.Files["Wzol7LyY2QVczXynJtVo"].Name, "smallfile2")
test.IsEqualString(t, settings.Files["e4TjE7CokWK0giiLNxDL"].Name, "smallfile2")
test.IsEqualString(t, settings.Files["wefffewhtrhhtrhtrhtr"].Name, "smallfile3")
file, _ = GetFile("Wzol7LyY2QVczXynJtVo")
file.DownloadsRemaining = 0
configuration.ServerSettings.Files["Wzol7LyY2QVczXynJtVo"] = file
settings.Files["Wzol7LyY2QVczXynJtVo"] = file
CleanUp(false)
test.IsEqualBool(t, helper.FileExists("test/data/e017693e4a04a59d0b0f400fe98177fe7ee13cf7"), true)
test.IsEqualString(t, configuration.ServerSettings.Files["Wzol7LyY2QVczXynJtVo"].Name, "")
test.IsEqualString(t, configuration.ServerSettings.Files["n1tSTAGj8zan9KaT4u6p"].Name, "")
test.IsEqualString(t, configuration.ServerSettings.Files["deletedfile123456789"].Name, "")
test.IsEqualString(t, configuration.ServerSettings.Files["e4TjE7CokWK0giiLNxDL"].Name, "smallfile2")
test.IsEqualString(t, configuration.ServerSettings.Files["wefffewhtrhhtrhtrhtr"].Name, "smallfile3")
test.IsEqualString(t, settings.Files["Wzol7LyY2QVczXynJtVo"].Name, "")
test.IsEqualString(t, settings.Files["n1tSTAGj8zan9KaT4u6p"].Name, "")
test.IsEqualString(t, settings.Files["deletedfile123456789"].Name, "")
test.IsEqualString(t, settings.Files["e4TjE7CokWK0giiLNxDL"].Name, "smallfile2")
test.IsEqualString(t, settings.Files["wefffewhtrhhtrhtrhtr"].Name, "smallfile3")
file, _ = GetFile("e4TjE7CokWK0giiLNxDL")
file.DownloadsRemaining = 0
configuration.ServerSettings.Files["e4TjE7CokWK0giiLNxDL"] = file
settings.Files["e4TjE7CokWK0giiLNxDL"] = file
file, _ = GetFile("wefffewhtrhhtrhtrhtr")
file.DownloadsRemaining = 0
configuration.ServerSettings.Files["wefffewhtrhhtrhtrhtr"] = file
settings.Files["wefffewhtrhhtrhtrhtr"] = file
CleanUp(false)
test.IsEqualBool(t, helper.FileExists("test/data/e017693e4a04a59d0b0f400fe98177fe7ee13cf7"), false)
test.IsEqualString(t, configuration.ServerSettings.Files["Wzol7LyY2QVczXynJtVo"].Name, "")
test.IsEqualString(t, configuration.ServerSettings.Files["n1tSTAGj8zan9KaT4u6p"].Name, "")
test.IsEqualString(t, configuration.ServerSettings.Files["deletedfile123456789"].Name, "")
test.IsEqualString(t, configuration.ServerSettings.Files["e4TjE7CokWK0giiLNxDL"].Name, "")
test.IsEqualString(t, configuration.ServerSettings.Files["wefffewhtrhhtrhtrhtr"].Name, "")
test.IsEqualString(t, settings.Files["Wzol7LyY2QVczXynJtVo"].Name, "")
test.IsEqualString(t, settings.Files["n1tSTAGj8zan9KaT4u6p"].Name, "")
test.IsEqualString(t, settings.Files["deletedfile123456789"].Name, "")
test.IsEqualString(t, settings.Files["e4TjE7CokWK0giiLNxDL"].Name, "")
test.IsEqualString(t, settings.Files["wefffewhtrhhtrhtrhtr"].Name, "")
test.IsEqualString(t, configuration.ServerSettings.Files["cleanuptest123456789"].Name, "cleanup")
test.IsEqualString(t, settings.Files["cleanuptest123456789"].Name, "cleanup")
test.IsEqualBool(t, helper.FileExists("test/data/2341354656543213246465465465432456898794"), true)
configuration.ServerSettings.DownloadStatus = make(map[string]models.DownloadStatus)
settings.DownloadStatus = make(map[string]models.DownloadStatus)
CleanUp(false)
test.IsEqualString(t, configuration.ServerSettings.Files["cleanuptest123456789"].Name, "")
test.IsEqualString(t, settings.Files["cleanuptest123456789"].Name, "")
test.IsEqualBool(t, helper.FileExists("test/data/2341354656543213246465465465432456898794"), false)
}
func TestDeleteFile(t *testing.T) {
testconfiguration.Create(true)
configuration.Load()
settings := configuration.GetServerSettings()
configuration.Release()
test.IsEqualString(t, settings.Files["n1tSTAGj8zan9KaT4u6p"].Name, "picture.jpg")
test.IsEqualBool(t, helper.FileExists("test/data/a8fdc205a9f19cc1c7507a60c4f01b13d11d7fd0"), true)
DeleteFile("n1tSTAGj8zan9KaT4u6p")
test.IsEqualString(t, settings.Files["n1tSTAGj8zan9KaT4u6p"].Name, "")
test.IsEqualBool(t, helper.FileExists("test/data/a8fdc205a9f19cc1c7507a60c4f01b13d11d7fd0"), false)
}

View File

@@ -44,8 +44,17 @@ var imageExpiredPicture []byte
const expiredFile = "static/expired.png"
var (
webserverPort string
webserverExtUrl string
webserverRedirectUrl string
webserverAdminName string
WebserverAdminPassword string
)
// Start the webserver on the port set in the config
func Start() {
initLocalVariables()
initTemplates(templateFolderEmbedded)
webserverDir, _ := fs.Sub(staticFolderEmbedded, "web/static")
var err error
@@ -70,16 +79,26 @@ func Start() {
http.HandleFunc("/delete", deleteFile)
http.HandleFunc("/downloadFile", downloadFile)
http.HandleFunc("/forgotpw", forgotPassword)
fmt.Println("Binding webserver to " + configuration.ServerSettings.Port)
fmt.Println("Webserver can be accessed at " + configuration.ServerSettings.ServerUrl + "admin")
fmt.Println("Binding webserver to " + webserverPort)
fmt.Println("Webserver can be accessed at " + webserverExtUrl + "admin")
srv := &http.Server{
Addr: configuration.ServerSettings.Port,
Addr: webserverPort,
ReadTimeout: timeOutWebserver,
WriteTimeout: timeOutWebserver,
}
log.Fatal(srv.ListenAndServe())
}
func initLocalVariables() {
settings := configuration.GetServerSettings()
webserverPort = settings.Port
webserverExtUrl = settings.ServerUrl
webserverRedirectUrl = settings.RedirectUrl
webserverAdminName = settings.AdminName
WebserverAdminPassword = settings.AdminPassword
configuration.Release()
}
// Initialises the templateFolder variable by scanning through all the templates.
// If a folder "templates" exists in the main directory, it is used.
// Otherwise templateFolderEmbedded will be used.
@@ -108,7 +127,7 @@ func doLogout(w http.ResponseWriter, r *http.Request) {
// Handling of /index and redirecting to globalConfig.RedirectUrl
func showIndex(w http.ResponseWriter, r *http.Request) {
err := templateFolder.ExecuteTemplate(w, "index", genericView{RedirectUrl: configuration.ServerSettings.RedirectUrl})
err := templateFolder.ExecuteTemplate(w, "index", genericView{RedirectUrl: webserverRedirectUrl})
helper.Check(err)
}
@@ -134,7 +153,7 @@ func showLogin(w http.ResponseWriter, r *http.Request) {
pw := r.Form.Get("password")
failedLogin := false
if pw != "" && user != "" {
if strings.ToLower(user) == strings.ToLower(configuration.ServerSettings.AdminName) && configuration.HashPassword(pw, false) == configuration.ServerSettings.AdminPassword {
if strings.ToLower(user) == strings.ToLower(webserverAdminName) && configuration.HashPassword(pw, false) == WebserverAdminPassword {
sessionmanager.CreateSession(w, false)
redirect(w, "admin")
return
@@ -213,7 +232,7 @@ func showHotlink(w http.ResponseWriter, r *http.Request) {
}
// Handling of /delete
// User needs to be admin. Deleted the requested file
// User needs to be admin. Deletes the requested file
func deleteFile(w http.ResponseWriter, r *http.Request) {
if !isAuthenticated(w, r, false) {
return
@@ -222,10 +241,7 @@ func deleteFile(w http.ResponseWriter, r *http.Request) {
if keyId == "" {
return
}
item := configuration.ServerSettings.Files[keyId]
item.ExpireAt = 0
configuration.ServerSettings.Files[keyId] = item
storage.CleanUp(false)
storage.DeleteFile(keyId)
redirect(w, "admin")
}
@@ -233,7 +249,7 @@ func deleteFile(w http.ResponseWriter, r *http.Request) {
// Stops for 500ms to limit brute forcing if invalid key and redirects to redirectUrl
func queryUrl(w http.ResponseWriter, r *http.Request, redirectUrl string) string {
keys, ok := r.URL.Query()["id"]
if !ok || len(keys[0]) < configuration.ServerSettings.LengthId {
if !ok || len(keys[0]) < configuration.GetLengthId() {
time.Sleep(500 * time.Millisecond)
redirect(w, redirectUrl)
return ""
@@ -278,7 +294,8 @@ type UploadView struct {
// the admin template
func (u *UploadView) convertGlobalConfig() *UploadView {
var result []models.File
for _, element := range configuration.ServerSettings.Files {
settings := configuration.GetServerSettings()
for _, element := range settings.Files {
result = append(result, element)
}
sort.Slice(result[:], func(i, j int) bool {
@@ -287,15 +304,16 @@ func (u *UploadView) convertGlobalConfig() *UploadView {
}
return result[i].ExpireAt > result[j].ExpireAt
})
u.Url = configuration.ServerSettings.ServerUrl + "d?id="
u.HotlinkUrl = configuration.ServerSettings.ServerUrl + "hotlink/"
u.DefaultPassword = configuration.ServerSettings.DefaultPassword
u.Url = settings.ServerUrl + "d?id="
u.HotlinkUrl = settings.ServerUrl + "hotlink/"
u.DefaultPassword = settings.DefaultPassword
u.Items = result
u.DefaultExpiry = configuration.ServerSettings.DefaultExpiry
u.DefaultDownloads = configuration.ServerSettings.DefaultDownloads
u.DefaultExpiry = settings.DefaultExpiry
u.DefaultDownloads = settings.DefaultDownloads
u.TimeNow = time.Now().Unix()
u.IsAdminView = true
u.IsMainView = true
configuration.Release()
return u
}
@@ -312,22 +330,24 @@ func uploadFile(w http.ResponseWriter, r *http.Request) {
expiryDays := r.Form.Get("expiryDays")
password := r.Form.Get("password")
allowedDownloadsInt, err := strconv.Atoi(allowedDownloads)
settings := configuration.GetServerSettings()
if err != nil {
allowedDownloadsInt = configuration.ServerSettings.DefaultDownloads
allowedDownloadsInt = settings.DefaultDownloads
}
expiryDaysInt, err := strconv.Atoi(expiryDays)
if err != nil {
expiryDaysInt = configuration.ServerSettings.DefaultExpiry
expiryDaysInt = settings.DefaultExpiry
}
configuration.ServerSettings.DefaultExpiry = expiryDaysInt
configuration.ServerSettings.DefaultDownloads = allowedDownloadsInt
configuration.ServerSettings.DefaultPassword = password
settings.DefaultExpiry = expiryDaysInt
settings.DefaultDownloads = allowedDownloadsInt
settings.DefaultPassword = password
configuration.Release()
file, header, err := r.FormFile("file")
responseError(w, err)
result, err := storage.NewFile(file, header, time.Now().Add(time.Duration(expiryDaysInt)*time.Hour*24).Unix(), allowedDownloadsInt, password)
responseError(w, err)
defer file.Close()
_, err = fmt.Fprint(w, result.ToJsonResult(configuration.ServerSettings.ServerUrl))
_, err = fmt.Fprint(w, result.ToJsonResult(webserverExtUrl))
helper.Check(err)
}

View File

@@ -46,11 +46,13 @@ func TestIndexRedirect(t *testing.T) {
}
func TestIndexFile(t *testing.T) {
t.Parallel()
settings := configuration.GetServerSettings()
testconfiguration.HttpPageResult(t, testconfiguration.HttpTestConfig{
Url: "http://localhost:53843/index",
RequiredContent: []string{configuration.ServerSettings.RedirectUrl},
RequiredContent: []string{settings.RedirectUrl},
IsHtml: true,
})
configuration.Release()
}
func TestStaticDirs(t *testing.T) {
t.Parallel()

View File

@@ -9,12 +9,15 @@ import (
"Gokapi/internal/helper"
"Gokapi/internal/models"
"net/http"
"sync"
"time"
)
// If no login occurred during this time, the admin session will be deleted. Default 30 days
const cookieLifeAdmin = 30 * 24 * time.Hour
var mutex sync.Mutex
// IsValidSession checks if the user is submitting a valid session token
// If valid session is found, useSession will be called
// Returns true if authenticated, otherwise false
@@ -23,9 +26,10 @@ func IsValidSession(w http.ResponseWriter, r *http.Request) bool {
if err == nil {
sessionString := cookie.Value
if sessionString != "" {
configuration.LockSessions()
defer func() { configuration.UnlockSessionsAndSave() }()
_, ok := configuration.ServerSettings.Sessions[sessionString]
mutex.Lock()
sessions := configuration.GetSessions()
defer func() { unlockAndSave() }()
_, ok := (*sessions)[sessionString]
if ok {
return useSession(w, sessionString)
}
@@ -34,19 +38,25 @@ func IsValidSession(w http.ResponseWriter, r *http.Request) bool {
return false
}
func unlockAndSave() {
configuration.Save()
mutex.Unlock()
}
// useSession checks if a session is still valid. It Changes the session string
// if it has // been used for more than an hour to limit session hijacking
// Returns true if session is still valid
// Returns false if session is invalid (and deletes it)
func useSession(w http.ResponseWriter, sessionString string) bool {
session := configuration.ServerSettings.Sessions[sessionString]
sessions := configuration.GetSessions()
session := (*sessions)[sessionString]
if session.ValidUntil < time.Now().Unix() {
delete(configuration.ServerSettings.Sessions, sessionString)
delete(*sessions, sessionString)
return false
}
if session.RenewAt < time.Now().Unix() {
CreateSession(w, true)
delete(configuration.ServerSettings.Sessions, sessionString)
delete(*sessions, sessionString)
}
return true
}
@@ -54,11 +64,12 @@ func useSession(w http.ResponseWriter, sessionString string) bool {
// CreateSession creates a new session - called after login with correct username / password
func CreateSession(w http.ResponseWriter, isLocked bool) {
if !isLocked {
configuration.LockSessions()
defer func() { configuration.UnlockSessionsAndSave() }()
mutex.Lock()
defer func() { unlockAndSave() }()
}
sessionString := helper.GenerateRandomString(60)
configuration.ServerSettings.Sessions[sessionString] = models.Session{
sessions := configuration.GetSessions()
(*sessions)[sessionString] = models.Session{
RenewAt: time.Now().Add(time.Hour).Unix(),
ValidUntil: time.Now().Add(cookieLifeAdmin).Unix(),
}
@@ -69,9 +80,10 @@ func CreateSession(w http.ResponseWriter, isLocked bool) {
func LogoutSession(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_token")
if err == nil {
configuration.LockSessions()
delete(configuration.ServerSettings.Sessions, cookie.Value)
configuration.UnlockSessionsAndSave()
mutex.Lock()
sessions := configuration.GetSessions()
delete(*sessions, cookie.Value)
unlockAndSave()
}
writeSessionCookie(w, "", time.Now())
}