diff --git a/gofsutil.go b/gofsutil.go index ed58e92..10045cd 100644 --- a/gofsutil.go +++ b/gofsutil.go @@ -86,7 +86,7 @@ var ( ErrNotImplemented = errors.New("not implemented") // fs is the default FS instance. - fs FSinterface = &FS{ScanEntry: defaultEntryScanFunc, SysBlockDir: "/sys/block"} + fs FSinterface = &FS{ScanEntry: defaultEntryScanFunc} ) // ContextKey is a variable containing context-keys @@ -102,11 +102,6 @@ func UseMockFS() { fs = &mockfs{ScanEntry: defaultEntryScanFunc} } -// UseMockSysBlockDir creates a file system for testing. -func UseMockSysBlockDir(mockSysBlockDir string) { - fs = &FS{ScanEntry: defaultEntryScanFunc, SysBlockDir: mockSysBlockDir} -} - // GetDiskFormat uses 'lsblk' to see if the given disk is unformatted. func GetDiskFormat(ctx context.Context, disk string) (string, error) { return fs.GetDiskFormat(ctx, disk) diff --git a/gofsutil_fs.go b/gofsutil_fs.go index 88a4667..89ed783 100644 --- a/gofsutil_fs.go +++ b/gofsutil_fs.go @@ -19,12 +19,13 @@ import ( "golang.org/x/sys/unix" ) +// SysBlockDir is used to set the directory of block devices. +var SysBlockDir string = "/sys/block" + // FS provides many filesystem-specific functions, such as mount, format, etc. type FS struct { // ScanEntry is the function used to process mount table entries. ScanEntry EntryScanFunc - // SysBlockDir is used to set the directory of block devices. - SysBlockDir string } // GetDiskFormat uses 'lsblk' to see if the given disk is unformatted. diff --git a/gofsutil_mount_test.go b/gofsutil_mount_test.go index f783b88..a2f1939 100644 --- a/gofsutil_mount_test.go +++ b/gofsutil_mount_test.go @@ -89,7 +89,7 @@ func TestGetMounts(t *testing.T) { func TestGetSysBlockDevicesForVolumeWWN(t *testing.T) { tempDir := t.TempDir() - gofsutil.UseMockSysBlockDir(tempDir) + gofsutil.SysBlockDir = tempDir tests := []struct { name string @@ -153,7 +153,7 @@ func TestGetSysBlockDevicesForVolumeWWN(t *testing.T) { func TestGetNVMeController(t *testing.T) { tempDir := t.TempDir() - gofsutil.UseMockSysBlockDir(tempDir) + gofsutil.SysBlockDir = tempDir tests := map[string]struct { device string diff --git a/gofsutil_mount_unix.go b/gofsutil_mount_unix.go index 4b4f116..bb495b0 100644 --- a/gofsutil_mount_unix.go +++ b/gofsutil_mount_unix.go @@ -650,9 +650,9 @@ func (fs *FS) issueLIPToAllFCHosts(_ context.Context) error { func (fs *FS) getSysBlockDevicesForVolumeWWN(_ context.Context, volumeWWN string) ([]string, error) { start := time.Now() result := make([]string, 0) - sysBlocks, err := os.ReadDir(fs.SysBlockDir) + sysBlocks, err := os.ReadDir(SysBlockDir) if err != nil { - return result, fmt.Errorf("Error reading %s: %s", fs.SysBlockDir, err) + return result, fmt.Errorf("Error reading %s: %s", SysBlockDir, err) } for _, sysBlock := range sysBlocks { @@ -665,9 +665,9 @@ func (fs *FS) getSysBlockDevicesForVolumeWWN(_ context.Context, volumeWWN string // Set the WWID path based on the device type var wwidPath string if strings.HasPrefix(name, "nvme") { - wwidPath = fs.SysBlockDir + "/" + name + "/wwid" // For NVMe devices + wwidPath = SysBlockDir + "/" + name + "/wwid" // For NVMe devices } else { - wwidPath = fs.SysBlockDir + "/" + name + "/device/wwid" // For SCSI devices + wwidPath = SysBlockDir + "/" + name + "/device/wwid" // For SCSI devices } bytes, err := os.ReadFile(filepath.Clean(wwidPath)) @@ -746,7 +746,7 @@ func wwnMatches(nguid, wwn string) bool { // GetNVMeController retrieves the NVMe controller for a given NVMe device. func (fs *FS) getNVMeController(device string) (string, error) { - devicePath := filepath.Join(fs.SysBlockDir, device) + devicePath := filepath.Join(SysBlockDir, device) // Check if the device path exists if _, err := os.Stat(devicePath); os.IsNotExist(err) { diff --git a/gofsutil_unix_test.go b/gofsutil_unix_test.go index 938dc96..d4f95f1 100644 --- a/gofsutil_unix_test.go +++ b/gofsutil_unix_test.go @@ -132,6 +132,9 @@ func TestMountArgs(t *testing.T) { } func TestWWNToDevicePath(t *testing.T) { + tempDir := t.TempDir() + SysBlockDir = tempDir + tests := []struct { src string tgt string @@ -251,7 +254,7 @@ func TestValidateMountArgs(t *testing.T) { for _, tt := range tests { t.Run(tt.testname, func(t *testing.T) { - fs := FS{SysBlockDir: "string"} + fs := FS{} err := fs.validateMountArgs(tt.source, tt.target, tt.fstype, tt.opts...) assert.Equal(t, tt.expect, err) }) @@ -291,7 +294,7 @@ func TestDoMount(t *testing.T) { for _, tt := range tests { t.Run(tt.testname, func(t *testing.T) { - fs := FS{SysBlockDir: "string"} + fs := FS{} err := fs.doMount(tt.ctx, tt.mntCmnd, tt.source, tt.target, tt.fstype, tt.opts...) assert.Equal(t, true, strings.Contains(err.Error(), tt.expect)) }) @@ -319,7 +322,7 @@ func TestUnMount(t *testing.T) { for _, tt := range tests { t.Run(tt.testname, func(t *testing.T) { - fs := FS{SysBlockDir: "string"} + fs := FS{} err := fs.unmount(tt.ctx, tt.target) assert.Equal(t, true, strings.Contains(err.Error(), tt.expect)) }) @@ -398,7 +401,7 @@ func TestMultipathCommand(t *testing.T) { } for _, tt := range tests { t.Run(tt.testname, func(t *testing.T) { - fs := FS{SysBlockDir: "string"} + fs := FS{} _, err := fs.multipathCommand(tt.ctx, tt.timeoutSeconds, tt.chroot, tt.arguments...) assert.Equal(t, tt.expectErr, err) }) @@ -421,7 +424,7 @@ func TestIsBind(t *testing.T) { for _, tt := range tests { t.Run(tt.testname, func(t *testing.T) { - fs := FS{SysBlockDir: "string"} + fs := FS{} _, err := fs.isBind(tt.ctx, tt.opts...) assert.Equal(t, tt.expect, err) })