Add RestoreZip method to Client in postgres.go

This commit is contained in:
Luis Eduardo Jeréz Girón
2024-08-03 01:11:49 -06:00
parent dfdd26daa1
commit 6497f4fa3e
+81
View File
@@ -5,6 +5,8 @@ import (
"bytes"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"github.com/orsinium-labs/enum"
@@ -195,3 +197,82 @@ func (c *Client) DumpZip(
return reader
}
// RestoreZip downloads the ZIP from the given url, unzips it, and runs the
// psql command to restore the database.
func (Client) RestoreZip(
version PGVersion, connString string, zipURL string,
) error {
// Create a temporary directory
dir, err := os.MkdirTemp("", "pbw-restore-*")
if err != nil {
return fmt.Errorf("error creating temp dir: %w", err)
}
defer os.RemoveAll(dir)
// Download the ZIP file from the given URL
zipPath := fmt.Sprintf("%s/dump.zip", dir)
resp, err := http.Get(zipURL)
if err != nil {
return fmt.Errorf("error downloading ZIP file: %w", err)
}
defer resp.Body.Close()
out, err := os.Create(zipPath)
if err != nil {
return fmt.Errorf("error creating ZIP file: %w", err)
}
defer out.Close()
if _, err = io.Copy(out, resp.Body); err != nil {
return fmt.Errorf("error writing to ZIP file: %w", err)
}
// Unzip the file into the temp dir
zipReadCloser, err := zip.OpenReader(zipPath)
if err != nil {
return fmt.Errorf("error opening ZIP file: %w", err)
}
defer zipReadCloser.Close()
var dumpPath string
for _, file := range zipReadCloser.File {
if file.Name == "dump.sql" {
dumpPath = fmt.Sprintf("%s/%s", dir, file.Name)
fileReadCloser, err := file.Open()
if err != nil {
return fmt.Errorf("error opening dump.sql in ZIP file: %w", err)
}
defer fileReadCloser.Close()
outFile, err := os.Create(dumpPath)
if err != nil {
return fmt.Errorf("error creating dump.sql: %w", err)
}
defer outFile.Close()
if _, err = io.Copy(outFile, fileReadCloser); err != nil {
return fmt.Errorf("error writing dump.sql: %w", err)
}
break
}
}
if dumpPath == "" {
return fmt.Errorf("dump.sql not found in ZIP file")
}
// Run the psql command to restore the database
cmd := exec.Command(version.Value.psql, connString, "-f", dumpPath)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf(
"error running psql v%s command: %s",
version.Value.version, output,
)
}
return nil
}