Skip to content

Commit

Permalink
fix the bwlcv methods for multivariate kde
Browse files Browse the repository at this point in the history
  • Loading branch information
panlanfeng committed Sep 17, 2017
1 parent f776585 commit f7821fd
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
11 changes: 8 additions & 3 deletions src/bandwidth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ function bwlcv(xdata::RealVector, kernel::Function)
return Optim.minimizer(Optim.optimize(h->lcv(xdata,kernel,h,w,n), hlb, hub, iterations=200,abs_tol=h0/n^2))
end

function lcv(xdata::RealMatrix, kernel::Array{Function, 1}, h::RealVector, w::Vector, n::Int)
function lcv(xdata::RealMatrix, kernel::Vector, h::RealVector, w::Vector, n::Int)
# -mean(kerneldensity(xdata,xdata,kernel,h)) + mean(map(kernel, xdata, xdata, h))
if any(h .<= 0.0)
return Inf
Expand All @@ -170,7 +170,7 @@ function lcv(xdata::RealMatrix, kernel::Array{Function, 1}, h::RealVector, w::Ve
end
-ll
end
function bwlcv(xdata::RealMatrix, kernel::Array{Function, 1})
function bwlcv(xdata::RealMatrix, kernel::Vector)
n, p = size(xdata)
w = ones(n)
h0 = zeros(p)
Expand All @@ -191,7 +191,12 @@ function bwlcv(xdata::RealMatrix, kernel::Array{Function, 1})
hub[j] = h0[j]
end
end
Optim.minimizer(Optim.optimize(h->lcv(xdata, kernel, h, w, n), h0))
h = Optim.minimizer(Optim.optimize(h->lcv(xdata, kernel, h, w, n), h0))
if all(hlb .<= h .<= hub)
return h
else
return h0
end
end


Expand Down
9 changes: 5 additions & 4 deletions src/density.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ function kerneldensity(xdata::RealVector; xeval::RealVector=xdata, lb::Real=-Inf
return den
end

function kerneldensity(xdata::RealMatrix; xeval::RealMatrix=xdata,
kernel::Array{Function, 1}=[gaussiankernel for i in 1:size(xdata)[2]], h::RealVector=bwlcv(xdata, kernel))
function kerneldensity(xdata::RealMatrix; xeval::RealMatrix=xdata,
kernel::Vector=[gaussiankernel for i in 1:size(xdata)[2]], h::RealVector=-Inf .* ones(size(xdata, 2)))

if any(h .<= 0)
error("h < 0!")
h = bwlcv(xdata, kernel)
warn("The user are responsible for giving the bandwidth! The defaults may not work well.")
end
m, p=size(xeval)
n, p1 = size(xdata)
Expand Down
6 changes: 6 additions & 0 deletions test/testreg.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

##Univariate kerneldensity and regression
using Distributions
srand(2017);
x=rand(Normal(10), 500)
xeval=linspace(minimum(x), maximum(x), 100)
h = bwlscv(x, gaussiankernel)
Expand Down Expand Up @@ -28,6 +29,11 @@ yfit1=npr(x, y, xeval=xeval, reg=locallinear)
cb=bootstrapCB(x, y, xeval=xeval)
@test mean(vec(cb[1,:]) .<= yfit1 .<= vec(cb[2,:])) > .8

#multivariate density estimation
x = rand(Normal(10), 500, 3)
denvalues = kerneldensity(x, h = [1.0, 1.0, 1.0])
@test all(denvalues .>= 0)

#multivariate regression
x = rand(Normal(10), 500, 3)
y = x * ones(3) .+ x.^2 *ones(3)
Expand Down

0 comments on commit f7821fd

Please sign in to comment.