Skip to content

Commit

Permalink
Version 0.4 incorporating new pytorch's linalg for version >1.9
Browse files Browse the repository at this point in the history
  • Loading branch information
JuanFMontesinos committed Jul 5, 2021
1 parent 0b46b65 commit af324dd
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 23 deletions.
21 changes: 10 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ Nvidia RTX 3090, ADM Threadripper 1920X for a single run


```
.bss_eval_sources test> permutation: False float32 CPU: 3.004 torch-CPU: 2.512
.bss_eval_sources test> Compute permutation: False float32 CPU: 2.590 GPU: 1.616
.bss_eval_sources test> compute_permutation: True float32 CPU: 13.195 torch-CPU: 12.326
.bss_eval_sources test> compute_permutation: True float32 CPU: 13.407 GPU: 7.779
.bss_eval_sources test> Compute permutation: False float64 CPU: 2.615 torch-CPU: 4.612
.bss_eval_sources test> Compute permutation: False float64 CPU: 2.578 GPU: 1.941
.bss_eval_sources test> Compute permutation: False float64 CPU: 12.949 torch-CPU: 22.422
.bss_eval_sources test> Compute permutation: False float64 CPU: 13.142 GPU: 9.404
.bss_eval_sources test> permutation: False float32 CPU: 2.582 torch-CPU: 1.670
.bss_eval_sources test> Compute permutation: False float32 CPU: 2.632 GPU: 1.286
.bss_eval_sources test> compute_permutation: True float32 CPU: 12.877 torch-CPU: 8.289
.bss_eval_sources test> compute_permutation: True float32 CPU: 12.768 GPU: 6.107
.bss_eval_sources test> Compute permutation: False float64 CPU: 2.576 torch-CPU: 2.793
.bss_eval_sources test> Compute permutation: False float64 CPU: 2.610 GPU: 1.752
.bss_eval_sources test> Compute permutation: False float64 CPU: 12.979 torch-CPU: 13.382
.bss_eval_sources test> Compute permutation: False float64 CPU: 18.769 GPU: 8.433
*Sources vary across tests
```
## Usage
Expand Down Expand Up @@ -58,9 +58,8 @@ Therefore the expected format is `b, nsrc, samples`
- Version 0.2: `bss_eval_sources` is now backpropagable. There algorithm now accepts batches.
- Version 0.3: Incorporates new PyTorch's fft package for versions>1.7 and deprecates `torch.rfft` and
`torch.ifft` following pytorch's roadmap.

## Warning
It seems there may be differences between GPU and CPU results in some corner cases (due to pytorch issues). CPU version seems to provide good results. CI unit tests are carried out for cpu tensors.
- Version 0.4: Support for PyTorch 1.9 onwards and the new linalg package which deprecates previous algebra solvers.
- Partially Solves inconsistencies between GPU results and CPU results shown at https://github.com/JuanFMontesinos/torch_mir_eval/issues/5

## Current available functions
* Separation:
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torch==1.8
torch==1.9
mir_eval
numpy
coverage
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
author='Juan Montesinos',
author_email='juanfelipe.montesinos@upf.edu',
packages=find_packages(),
install_requires=['torch>=1.7'],
install_requires=['torch>=1.9'],
classifiers=[
"Programming Language :: Python :: 3", ],
zip_safe=False)
2 changes: 1 addition & 1 deletion torch_mir_eval/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.post2"
__version__ = "0.4"
8 changes: 4 additions & 4 deletions torch_mir_eval/batch_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@

import itertools
import warnings
from math import log2, ceil

import torch
from torch.fft import fft,ifft
from math import log2, ceil
from torch.fft import fft, ifft

from .toeplitz import batch_toeplitz

Expand Down Expand Up @@ -327,11 +327,11 @@ def _project(reference_sources, estimated_source, flen):
# Distortion filters
# TODO case in which only few of the Gs are singular but the rest are determined
if all(torch.det(G) > 0.1):
C = torch.solve(D.unsqueeze(-1), G).solution.reshape(b, nsrc, flen).permute(0, 2, 1)
C = torch.linalg.solve(G, D.unsqueeze(-1)).reshape(b, nsrc, flen).permute(0, 2, 1)
else:
solutions = []
for d, g in zip(D, G):
c = torch.lstsq(d.unsqueeze(1), g).solution.reshape(nsrc, flen).T
c = torch.linalg.lstsq(g, d.unsqueeze(1)).solution.reshape(nsrc, flen).T
solutions.append(c)
C = torch.stack(solutions)
# Filtering
Expand Down
10 changes: 5 additions & 5 deletions torch_mir_eval/separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,16 @@

import itertools
import warnings
from math import ceil, log2

import torch
from torch.fft import rfft,ifft,fft
from math import ceil, log2
from torch.fft import ifft, fft

from .toeplitz import toeplitz

# The maximum allowable number of sources (prevents insane computational load)
MAX_SOURCES = 100
#maintained for testing purposes
# maintained for testing purposes

__all__ = ['bss_eval_sources']

Expand Down Expand Up @@ -312,9 +312,9 @@ def _project(reference_sources, estimated_source, flen):
# Computing projection
# Distortion filters
if torch.det(G) > 0.1:
C = torch.solve(D.unsqueeze(1), G).solution.reshape(nsrc, flen).T
C = torch.linalg.solve(G, D.unsqueeze(1)).reshape(nsrc, flen).T
else:
C = torch.lstsq(D.unsqueeze(1), G).solution.reshape(nsrc, flen).T
C = torch.linalg.lstsq(G,D.unsqueeze(1)).solution.reshape(nsrc, flen).T
# Filtering
sproj = torch.zeros(nsampl + flen - 1, **kw)
for i in range(nsrc):
Expand Down

0 comments on commit af324dd

Please sign in to comment.