Skip to content

Commit 54a6dcf

Browse files
authored
Add sync.Mutex and os.Rename to prevent corrupted file when downloading the Postgres archive (#105)
* Add sync.Mutex to prevent collisions * * add atomic download, and use defer to ensure mutex unlock * * move mutex to global * * fix tests * * update platform-test * * update examples * * remove atomic dependency * * remove code duplication * * reduce test parallel run count * * fix race condition in tests * * run tests * * attempt to fix windows * * attempt a different solution for windows * * revert changes * * try additional fix for windows * * add another test * * catch syscall.EEXIST * * fix test * * fix test * * add extra debugging * * attempt to fix windows * * add additional error message * * fix race in decompression * * more fixes * * use atomic * * add extra debug * * try catching the error * * try different permissions * * add more debugging * * more debug * * more debug * * test dest * * attempt to close temp file * * simplify * * remove atomic * * clean up code * * add more tests * * clean up temporary files * * prevent file being opened twice
1 parent cd07cf7 commit 54a6dcf

9 files changed

+343
-48
lines changed

decompression.go

+38-7
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,19 @@ func defaultTarReader(xzReader *xz.Reader) (func() (*tar.Header, error), func()
2121
}
2222

2323
func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), func() io.Reader), path, extractPath string) error {
24+
tempExtractPath, err := os.MkdirTemp(filepath.Dir(extractPath), "temp_")
25+
if err != nil {
26+
return errorUnableToExtract(path, extractPath, err)
27+
}
28+
defer func() {
29+
if err := os.RemoveAll(tempExtractPath); err != nil {
30+
panic(err)
31+
}
32+
}()
33+
2434
tarFile, err := os.Open(path)
2535
if err != nil {
26-
return errorUnableToExtract(path, extractPath)
36+
return errorUnableToExtract(path, extractPath, err)
2737
}
2838

2939
defer func() {
@@ -34,7 +44,7 @@ func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), fu
3444

3545
xzReader, err := xz.NewReader(tarFile, 0)
3646
if err != nil {
37-
return errorUnableToExtract(path, extractPath)
47+
return errorUnableToExtract(path, extractPath, err)
3848
}
3949

4050
readNext, reader := tarReader(xzReader)
@@ -43,16 +53,21 @@ func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), fu
4353
header, err := readNext()
4454

4555
if err == io.EOF {
46-
return nil
56+
break
4757
}
4858

4959
if err != nil {
5060
return errorExtractingPostgres(err)
5161
}
5262

53-
targetPath := filepath.Join(extractPath, header.Name)
63+
targetPath := filepath.Join(tempExtractPath, header.Name)
64+
finalPath := filepath.Join(extractPath, header.Name)
5465

55-
if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
66+
if err := os.MkdirAll(filepath.Dir(targetPath), os.ModePerm); err != nil {
67+
return errorExtractingPostgres(err)
68+
}
69+
70+
if err := os.MkdirAll(filepath.Dir(finalPath), os.ModePerm); err != nil {
5671
return errorExtractingPostgres(err)
5772
}
5873

@@ -78,10 +93,26 @@ func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), fu
7893
if err := os.Symlink(header.Linkname, targetPath); err != nil {
7994
return errorExtractingPostgres(err)
8095
}
96+
97+
case tar.TypeDir:
98+
if err := os.MkdirAll(finalPath, os.FileMode(header.Mode)); err != nil {
99+
return errorExtractingPostgres(err)
100+
}
101+
continue
102+
}
103+
104+
if err := renameOrIgnore(targetPath, finalPath); err != nil {
105+
return errorExtractingPostgres(err)
81106
}
82107
}
108+
109+
return nil
83110
}
84111

85-
func errorUnableToExtract(cacheLocation, binariesPath string) error {
86-
return fmt.Errorf("unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories", cacheLocation, binariesPath)
112+
func errorUnableToExtract(cacheLocation, binariesPath string, err error) error {
113+
return fmt.Errorf("unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories, %w",
114+
cacheLocation,
115+
binariesPath,
116+
err,
117+
)
87118
}

decompression_test.go

+42-1
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,16 @@ package embeddedpostgres
33
import (
44
"archive/tar"
55
"errors"
6+
"fmt"
67
"io"
78
"os"
9+
"path"
810
"path/filepath"
11+
"syscall"
912
"testing"
1013

1114
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
1216
"github.com/xi2/xz"
1317
)
1418

@@ -17,6 +21,9 @@ func Test_decompressTarXz(t *testing.T) {
1721
if err != nil {
1822
panic(err)
1923
}
24+
if err := syscall.Rmdir(tempDir); err != nil {
25+
panic(err)
26+
}
2027

2128
archive, cleanUp := createTempXzArchive()
2229
defer cleanUp()
@@ -37,14 +44,22 @@ func Test_decompressTarXz(t *testing.T) {
3744
func Test_decompressTarXz_ErrorWhenFileNotExists(t *testing.T) {
3845
err := decompressTarXz(defaultTarReader, "/does-not-exist", "/also-fake")
3946

40-
assert.EqualError(t, err, "unable to extract postgres archive /does-not-exist to /also-fake, if running parallel tests, configure RuntimePath to isolate testing directories")
47+
assert.Error(t, err)
48+
assert.Contains(
49+
t,
50+
err.Error(),
51+
"unable to extract postgres archive /does-not-exist to /also-fake, if running parallel tests, configure RuntimePath to isolate testing directories",
52+
)
4153
}
4254

4355
func Test_decompressTarXz_ErrorWhenErrorDuringRead(t *testing.T) {
4456
tempDir, err := os.MkdirTemp("", "temp_tar_test")
4557
if err != nil {
4658
panic(err)
4759
}
60+
if err := syscall.Rmdir(tempDir); err != nil {
61+
panic(err)
62+
}
4863

4964
archive, cleanUp := createTempXzArchive()
5065
defer cleanUp()
@@ -103,6 +118,9 @@ func Test_decompressTarXz_ErrorWhenFileToCopyToNotExists(t *testing.T) {
103118
if err != nil {
104119
panic(err)
105120
}
121+
if err := syscall.Rmdir(tempDir); err != nil {
122+
panic(err)
123+
}
106124

107125
archive, cleanUp := createTempXzArchive()
108126
defer cleanUp()
@@ -137,6 +155,9 @@ func Test_decompressTarXz_ErrorWhenArchiveCorrupted(t *testing.T) {
137155
if err != nil {
138156
panic(err)
139157
}
158+
if err := syscall.Rmdir(tempDir); err != nil {
159+
panic(err)
160+
}
140161

141162
archive, cleanup := createTempXzArchive()
142163

@@ -163,3 +184,23 @@ func Test_decompressTarXz_ErrorWhenArchiveCorrupted(t *testing.T) {
163184

164185
assert.EqualError(t, err, "unable to extract postgres archive: xz: data is corrupt")
165186
}
187+
188+
func Test_decompressTarXz_ErrorWithInvalidDestination(t *testing.T) {
189+
archive, cleanUp := createTempXzArchive()
190+
defer cleanUp()
191+
192+
tempDir, err := os.MkdirTemp("", "temp_tar_test")
193+
require.NoError(t, err)
194+
defer func() {
195+
os.RemoveAll(tempDir)
196+
}()
197+
198+
op := fmt.Sprintf(path.Join(tempDir, "%c"), rune(0))
199+
200+
err = decompressTarXz(defaultTarReader, archive, op)
201+
assert.EqualError(
202+
t,
203+
err,
204+
fmt.Sprintf("unable to extract postgres archive: mkdir %s: invalid argument", op),
205+
)
206+
}

embedded_postgres.go

+26-12
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@ import (
99
"path/filepath"
1010
"runtime"
1111
"strings"
12+
"sync"
1213
)
1314

15+
var mu sync.Mutex
16+
1417
// EmbeddedPostgres maintains all configuration and runtime functions for maintaining the lifecycle of one Postgres process.
1518
type EmbeddedPostgres struct {
1619
config Config
@@ -92,20 +95,11 @@ func (ep *EmbeddedPostgres) Start() error {
9295
ep.config.binariesPath = ep.config.runtimePath
9396
}
9497

95-
_, binDirErr := os.Stat(filepath.Join(ep.config.binariesPath, "bin"))
96-
if os.IsNotExist(binDirErr) {
97-
if !cacheExists {
98-
if err := ep.remoteFetchStrategy(); err != nil {
99-
return err
100-
}
101-
}
102-
103-
if err := decompressTarXz(defaultTarReader, cacheLocation, ep.config.binariesPath); err != nil {
104-
return err
105-
}
98+
if err := ep.downloadAndExtractBinary(cacheExists, cacheLocation); err != nil {
99+
return err
106100
}
107101

108-
if err := os.MkdirAll(ep.config.runtimePath, 0755); err != nil {
102+
if err := os.MkdirAll(ep.config.runtimePath, os.ModePerm); err != nil {
109103
return fmt.Errorf("unable to create runtime directory %s with error: %s", ep.config.runtimePath, err)
110104
}
111105

@@ -148,6 +142,26 @@ func (ep *EmbeddedPostgres) Start() error {
148142
return nil
149143
}
150144

145+
func (ep *EmbeddedPostgres) downloadAndExtractBinary(cacheExists bool, cacheLocation string) error {
146+
// lock to prevent collisions with duplicate downloads
147+
mu.Lock()
148+
defer mu.Unlock()
149+
150+
_, binDirErr := os.Stat(filepath.Join(ep.config.binariesPath, "bin"))
151+
if os.IsNotExist(binDirErr) {
152+
if !cacheExists {
153+
if err := ep.remoteFetchStrategy(); err != nil {
154+
return err
155+
}
156+
}
157+
158+
if err := decompressTarXz(defaultTarReader, cacheLocation, ep.config.binariesPath); err != nil {
159+
return err
160+
}
161+
}
162+
return nil
163+
}
164+
151165
func (ep *EmbeddedPostgres) cleanDataDirectoryAndInit() error {
152166
if err := os.RemoveAll(ep.config.dataPath); err != nil {
153167
return fmt.Errorf("unable to clean up data directory %s with error: %s", ep.config.dataPath, err)

embedded_postgres_test.go

+62-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"time"
1616

1717
"github.com/stretchr/testify/assert"
18+
"github.com/stretchr/testify/require"
1819
)
1920

2021
func Test_DefaultConfig(t *testing.T) {
@@ -99,7 +100,7 @@ func Test_ErrorWhenUnableToUnArchiveFile_WrongFormat(t *testing.T) {
99100
}
100101
}
101102

102-
assert.EqualError(t, err, fmt.Sprintf(`unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories`, jarFile, filepath.Join(filepath.Dir(jarFile), "extracted")))
103+
assert.EqualError(t, err, fmt.Sprintf(`unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories, xz: file format not recognized`, jarFile, filepath.Join(filepath.Dir(jarFile), "extracted")))
103104
}
104105

105106
func Test_ErrorWhenUnableToInitDatabase(t *testing.T) {
@@ -355,6 +356,66 @@ func Test_CustomLocaleConfig(t *testing.T) {
355356
}
356357
}
357358

359+
func Test_ConcurrentStart(t *testing.T) {
360+
var wg sync.WaitGroup
361+
362+
database := NewDatabase()
363+
cacheLocation, _ := database.cacheLocator()
364+
err := os.RemoveAll(cacheLocation)
365+
require.NoError(t, err)
366+
367+
port := 5432
368+
for i := 1; i <= 3; i++ {
369+
port = port + 1
370+
wg.Add(1)
371+
372+
go func(p int) {
373+
defer wg.Done()
374+
tempDir, err := os.MkdirTemp("", "embedded_postgres_test")
375+
if err != nil {
376+
panic(err)
377+
}
378+
379+
defer func() {
380+
if err := os.RemoveAll(tempDir); err != nil {
381+
panic(err)
382+
}
383+
}()
384+
385+
database := NewDatabase(DefaultConfig().
386+
RuntimePath(tempDir).
387+
Port(uint32(p)))
388+
389+
if err := database.Start(); err != nil {
390+
shutdownDBAndFail(t, err, database)
391+
}
392+
393+
db, err := sql.Open(
394+
"postgres",
395+
fmt.Sprintf("host=localhost port=%d user=postgres password=postgres dbname=postgres sslmode=disable", p),
396+
)
397+
if err != nil {
398+
shutdownDBAndFail(t, err, database)
399+
}
400+
401+
if err = db.Ping(); err != nil {
402+
shutdownDBAndFail(t, err, database)
403+
}
404+
405+
if err := db.Close(); err != nil {
406+
shutdownDBAndFail(t, err, database)
407+
}
408+
409+
if err := database.Stop(); err != nil {
410+
shutdownDBAndFail(t, err, database)
411+
}
412+
413+
}(port)
414+
}
415+
416+
wg.Wait()
417+
}
418+
358419
func Test_CanStartAndStopTwice(t *testing.T) {
359420
database := NewDatabase()
360421

platform-test/platform_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ func Test_AllMajorVersions(t *testing.T) {
7272
}
7373

7474
func shutdownDBAndFail(t *testing.T, err error, db *embeddedpostgres.EmbeddedPostgres, version embeddedpostgres.PostgresVersion) {
75-
if err := db.Stop(); err != nil {
76-
t.Fatalf("Failed for version %s with error %s", version, err)
75+
if err2 := db.Stop(); err2 != nil {
76+
t.Fatalf("Failed for version %s with error %s, original error %s", version, err2, err)
7777
}
7878

7979
t.Fatalf("Failed for version %s with error %s", version, err)

0 commit comments

Comments
 (0)