Skip to content

Commit

Permalink
Merge pull request #872 from pavanky/modfix
Browse files Browse the repository at this point in the history
correctness fixes for mod and remainder for integer type tensors.
  • Loading branch information
soumith authored Dec 20, 2016
2 parents f9b37b0 + e1755e4 commit 7ca7ec9
Showing 1 changed file with 45 additions and 5 deletions.
50 changes: 45 additions & 5 deletions lib/TH/generic/THTensorMath.c
Original file line number Diff line number Diff line change
Expand Up @@ -516,10 +516,19 @@ void THTensor_(fmod)(THTensor *r_, THTensor *t, real value)
ptrdiff_t sz = THTensor_(nElement)(t);
ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
for (i=0; i<sz; i++) {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
rp[i] = fmod(tp[i], value);
#else
rp[i] = tp[i] % value;
#endif
}
} else {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
TH_TENSOR_APPLY2(real, r_, real, t, *r__data = fmod(*t_data, value););
#else
TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (*t_data % value););
#endif
}
}

Expand All @@ -532,10 +541,20 @@ void THTensor_(remainder)(THTensor *r_, THTensor *t, real value)
ptrdiff_t sz = THTensor_(nElement)(t);
ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
for (i=0; i<sz; i++) {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
rp[i] = (value == 0)? NAN : tp[i] - value * floor(tp[i] / value);
#else
rp[i] = tp[i] - value * (tp[i] / value); // There is no NAN for integers
#endif
}
} else {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (value == 0)? NAN : *t_data - value * floor(*t_data / value););
#else
// There is no NAN for integers
TH_TENSOR_APPLY2(real, r_, real, t, *r__data = *t_data - value * (*t_data / value););
#endif
}
}

Expand Down Expand Up @@ -643,10 +662,20 @@ void THTensor_(cfmod)(THTensor *r_, THTensor *t, THTensor *src)
ptrdiff_t sz = THTensor_(nElement)(t);
ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = fmod(tp[i], sp[i]);
for (i=0; i<sz; i++) {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
rp[i] = fmod(tp[i], sp[i]);
#else
rp[i] = tp[i] % sp[i];
#endif
}
} else {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = fmod(*t_data, *src_data););
#else
TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = (*t_data % *src_data););
#endif

}
}

Expand All @@ -660,10 +689,21 @@ void THTensor_(cremainder)(THTensor *r_, THTensor *t, THTensor *src)
ptrdiff_t sz = THTensor_(nElement)(t);
ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
for (i=0; i<sz; i++) {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
rp[i] = (sp[i] == 0)? NAN : tp[i] - sp[i] * floor(tp[i] / sp[i]);
#else
rp[i] = tp[i] - sp[i] * (tp[i] / sp[i]); // There is no NAN for integers
#endif
}
} else {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = (*src_data == 0)? NAN : *t_data - *src_data * floor(*t_data / *src_data););
#else
// There is no NAN for integers
TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = *t_data - *src_data * (*t_data / *src_data););
#endif

}
}

Expand Down

0 comments on commit 7ca7ec9

Please sign in to comment.