diff --git a/go.mod b/go.mod index 8bd3967..5dd5fbd 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/NVIDIA/go-nvlib go 1.20 require ( - github.com/NVIDIA/go-nvml v0.12.0-5 + github.com/NVIDIA/go-nvml v0.12.0-6 github.com/google/uuid v1.6.0 github.com/stretchr/testify v1.9.0 ) diff --git a/go.sum b/go.sum index 732111e..0b0a7c5 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/NVIDIA/go-nvml v0.12.0-5 h1:4DYsngBqJEAEj+/RFmBZ43Q3ymoR3tyS0oBuJk12Fag= -github.com/NVIDIA/go-nvml v0.12.0-5/go.mod h1:8Llmj+1Rr+9VGGwZuRer5N/aCjxGuR5nPb/9ebBiIEQ= +github.com/NVIDIA/go-nvml v0.12.0-6 h1:FJYc2KrpvX+VOC/8QQvMiQMmZ/nPMRpdJO/Ik4xfcr0= +github.com/NVIDIA/go-nvml v0.12.0-6/go.mod h1:8Llmj+1Rr+9VGGwZuRer5N/aCjxGuR5nPb/9ebBiIEQ= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/device.go b/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/device.go index 7ee5e55..7604d39 100644 --- a/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/device.go +++ b/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/device.go @@ -15,9 +15,48 @@ package nvml import ( + "fmt" + "reflect" "unsafe" ) +// nvmlDeviceHandle attempts to convert a device d to an nvmlDevice. +// This is required for functions such as GetTopologyCommonAncestor which +// accept Device arguments that need to be passed to internal nvml* functions +// as nvmlDevice parameters. +func nvmlDeviceHandle(d Device) nvmlDevice { + var helper func(val reflect.Value) nvmlDevice + helper = func(val reflect.Value) nvmlDevice { + if val.Kind() == reflect.Interface { + val = val.Elem() + } + + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + if val.Type() == reflect.TypeOf(nvmlDevice{}) { + return val.Interface().(nvmlDevice) + } + + if val.Kind() != reflect.Struct { + panic(fmt.Errorf("unable to convert non-struct type %v to nvmlDevice", val.Kind())) + } + + for i := 0; i < val.Type().NumField(); i++ { + if !val.Type().Field(i).Anonymous { + continue + } + if !val.Field(i).Type().Implements(reflect.TypeOf((*Device)(nil)).Elem()) { + continue + } + return helper(val.Field(i)) + } + panic(fmt.Errorf("unable to convert %T to nvmlDevice", d)) + } + return helper(reflect.ValueOf(d)) +} + // EccBitType type EccBitType = MemoryErrorType @@ -220,10 +259,13 @@ func (l *library) DeviceGetTopologyCommonAncestor(device1 Device, device2 Device func (device1 nvmlDevice) GetTopologyCommonAncestor(device2 Device) (GpuTopologyLevel, Return) { var pathInfo GpuTopologyLevel - ret := nvmlDeviceGetTopologyCommonAncestor(device1, device2.(nvmlDevice), &pathInfo) + ret := nvmlDeviceGetTopologyCommonAncestorStub(device1, nvmlDeviceHandle(device2), &pathInfo) return pathInfo, ret } +// nvmlDeviceGetTopologyCommonAncestorStub allows us to override this for testing. +var nvmlDeviceGetTopologyCommonAncestorStub = nvmlDeviceGetTopologyCommonAncestor + // nvml.DeviceGetTopologyNearestGpus() func (l *library) DeviceGetTopologyNearestGpus(device Device, level GpuTopologyLevel) ([]Device, Return) { return device.GetTopologyNearestGpus(level) @@ -250,7 +292,7 @@ func (l *library) DeviceGetP2PStatus(device1 Device, device2 Device, p2pIndex Gp func (device1 nvmlDevice) GetP2PStatus(device2 Device, p2pIndex GpuP2PCapsIndex) (GpuP2PStatus, Return) { var p2pStatus GpuP2PStatus - ret := nvmlDeviceGetP2PStatus(device1, device2.(nvmlDevice), p2pIndex, &p2pStatus) + ret := nvmlDeviceGetP2PStatus(device1, nvmlDeviceHandle(device2), p2pIndex, &p2pStatus) return p2pStatus, ret } @@ -1182,7 +1224,7 @@ func (l *library) DeviceOnSameBoard(device1 Device, device2 Device) (int, Return func (device1 nvmlDevice) OnSameBoard(device2 Device) (int, Return) { var onSameBoard int32 - ret := nvmlDeviceOnSameBoard(device1, device2.(nvmlDevice), &onSameBoard) + ret := nvmlDeviceOnSameBoard(device1, nvmlDeviceHandle(device2), &onSameBoard) return int(onSameBoard), ret } diff --git a/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/gpm.go b/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/gpm.go index 783514b..acdb2e0 100644 --- a/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/gpm.go +++ b/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/gpm.go @@ -93,7 +93,7 @@ func (device nvmlDevice) GpmSampleGet(gpmSample GpmSample) Return { } func (gpmSample nvmlGpmSample) Get(device Device) Return { - return nvmlGpmSampleGet(device.(nvmlDevice), gpmSample) + return nvmlGpmSampleGet(nvmlDeviceHandle(device), gpmSample) } // nvml.GpmQueryDeviceSupport() @@ -137,5 +137,5 @@ func (device nvmlDevice) GpmMigSampleGet(gpuInstanceId int, gpmSample GpmSample) } func (gpmSample nvmlGpmSample) MigGet(device Device, gpuInstanceId int) Return { - return nvmlGpmMigSampleGet(device.(nvmlDevice), uint32(gpuInstanceId), gpmSample) + return nvmlGpmMigSampleGet(nvmlDeviceHandle(device), uint32(gpuInstanceId), gpmSample) } diff --git a/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/vgpu.go b/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/vgpu.go index bd80077..da49524 100644 --- a/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/vgpu.go +++ b/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/vgpu.go @@ -142,7 +142,7 @@ func (device nvmlDevice) VgpuTypeGetMaxInstances(vgpuTypeId VgpuTypeId) (int, Re func (vgpuTypeId nvmlVgpuTypeId) GetMaxInstances(device Device) (int, Return) { var vgpuInstanceCount uint32 - ret := nvmlVgpuTypeGetMaxInstances(device.(nvmlDevice), vgpuTypeId, &vgpuInstanceCount) + ret := nvmlVgpuTypeGetMaxInstances(nvmlDeviceHandle(device), vgpuTypeId, &vgpuInstanceCount) return int(vgpuInstanceCount), ret } diff --git a/vendor/modules.txt b/vendor/modules.txt index 627aae9..f6a1e68 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,4 +1,4 @@ -# github.com/NVIDIA/go-nvml v0.12.0-5 +# github.com/NVIDIA/go-nvml v0.12.0-6 ## explicit; go 1.20 github.com/NVIDIA/go-nvml/pkg/dl github.com/NVIDIA/go-nvml/pkg/nvml