diff --git a/src/testing/testing.go b/src/testing/testing.go index c3d4b4ba..9c76f097 100644 --- a/src/testing/testing.go +++ b/src/testing/testing.go @@ -117,6 +117,7 @@ type TB interface { Log(args ...interface{}) Logf(format string, args ...interface{}) Name() string + Setenv(key, value string) Skip(args ...interface{}) SkipNow() Skipf(format string, args ...interface{}) @@ -330,6 +331,27 @@ func (c *common) TempDir() string { return dir } +// Setenv calls os.Setenv(key, value) and uses Cleanup to +// restore the environment variable to its original value +// after the test. +func (c *common) Setenv(key, value string) { + prevValue, ok := os.LookupEnv(key) + + if err := os.Setenv(key, value); err != nil { + c.Fatalf("cannot set environment variable: %v", err) + } + + if ok { + c.Cleanup(func() { + os.Setenv(key, prevValue) + }) + } else { + c.Cleanup(func() { + os.Unsetenv(key) + }) + } +} + // runCleanup is called at the end of the test. func (c *common) runCleanup() { for { diff --git a/src/testing/testing_test.go b/src/testing/testing_test.go index b0b5ff06..8a258653 100644 --- a/src/testing/testing_test.go +++ b/src/testing/testing_test.go @@ -138,3 +138,60 @@ func testTempDir(t *testing.T) { t.Errorf("unexpected files in TempDir") } } + +func TestSetenv(t *testing.T) { + tests := []struct { + name string + key string + initialValueExists bool + initialValue string + newValue string + }{ + { + name: "initial value exists", + key: "GO_TEST_KEY_1", + initialValueExists: true, + initialValue: "111", + newValue: "222", + }, + { + name: "initial value exists but empty", + key: "GO_TEST_KEY_2", + initialValueExists: true, + initialValue: "", + newValue: "222", + }, + { + name: "initial value is not exists", + key: "GO_TEST_KEY_3", + initialValueExists: false, + initialValue: "", + newValue: "222", + }, + } + + for _, test := range tests { + if test.initialValueExists { + if err := os.Setenv(test.key, test.initialValue); err != nil { + t.Fatalf("unable to set env: got %v", err) + } + } else { + os.Unsetenv(test.key) + } + + t.Run(test.name, func(t *testing.T) { + t.Setenv(test.key, test.newValue) + if os.Getenv(test.key) != test.newValue { + t.Fatalf("unexpected value after t.Setenv: got %s, want %s", os.Getenv(test.key), test.newValue) + } + }) + + got, exists := os.LookupEnv(test.key) + if got != test.initialValue { + t.Fatalf("unexpected value after t.Setenv cleanup: got %s, want %s", got, test.initialValue) + } + if exists != test.initialValueExists { + t.Fatalf("unexpected value after t.Setenv cleanup: got %t, want %t", exists, test.initialValueExists) + } + } +}