From 6497f4fa3e8c0c5e1da0fc879925fb540f725e41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20Eduardo=20Jer=C3=A9z=20Gir=C3=B3n?= Date: Sat, 3 Aug 2024 01:11:49 -0600 Subject: [PATCH] Add RestoreZip method to Client in postgres.go --- internal/integration/postgres/postgres.go | 81 +++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/internal/integration/postgres/postgres.go b/internal/integration/postgres/postgres.go index f50784b..daf263b 100644 --- a/internal/integration/postgres/postgres.go +++ b/internal/integration/postgres/postgres.go @@ -5,6 +5,8 @@ import ( "bytes" "fmt" "io" + "net/http" + "os" "os/exec" "github.com/orsinium-labs/enum" @@ -195,3 +197,82 @@ func (c *Client) DumpZip( return reader } + +// RestoreZip downloads the ZIP from the given url, unzips it, and runs the +// psql command to restore the database. +func (Client) RestoreZip( + version PGVersion, connString string, zipURL string, +) error { + // Create a temporary directory + dir, err := os.MkdirTemp("", "pbw-restore-*") + if err != nil { + return fmt.Errorf("error creating temp dir: %w", err) + } + defer os.RemoveAll(dir) + + // Download the ZIP file from the given URL + zipPath := fmt.Sprintf("%s/dump.zip", dir) + resp, err := http.Get(zipURL) + if err != nil { + return fmt.Errorf("error downloading ZIP file: %w", err) + } + defer resp.Body.Close() + + out, err := os.Create(zipPath) + if err != nil { + return fmt.Errorf("error creating ZIP file: %w", err) + } + defer out.Close() + + if _, err = io.Copy(out, resp.Body); err != nil { + return fmt.Errorf("error writing to ZIP file: %w", err) + } + + // Unzip the file into the temp dir + zipReadCloser, err := zip.OpenReader(zipPath) + if err != nil { + return fmt.Errorf("error opening ZIP file: %w", err) + } + defer zipReadCloser.Close() + + var dumpPath string + for _, file := range zipReadCloser.File { + if file.Name == "dump.sql" { + dumpPath = fmt.Sprintf("%s/%s", dir, file.Name) + + fileReadCloser, err := file.Open() + if err != nil { + return fmt.Errorf("error opening dump.sql in ZIP file: %w", err) + } + defer fileReadCloser.Close() + + outFile, err := os.Create(dumpPath) + if err != nil { + return fmt.Errorf("error creating dump.sql: %w", err) + } + defer outFile.Close() + + if _, err = io.Copy(outFile, fileReadCloser); err != nil { + return fmt.Errorf("error writing dump.sql: %w", err) + } + + break + } + } + + if dumpPath == "" { + return fmt.Errorf("dump.sql not found in ZIP file") + } + + // Run the psql command to restore the database + cmd := exec.Command(version.Value.psql, connString, "-f", dumpPath) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf( + "error running psql v%s command: %s", + version.Value.version, output, + ) + } + + return nil +}