Skip to content

Commit

Permalink
use amrex removeInvalidParticles for the beam (#1067)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderSinn authored Feb 22, 2024
1 parent c2f50ff commit 4a257f4
Showing 1 changed file with 17 additions and 76 deletions.
93 changes: 17 additions & 76 deletions src/particles/sorting/SliceSort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,21 @@ shiftSlippedParticles (BeamParticleContainer& beam, const int slice, amrex::Geom
return;
}

const auto ptd = beam.getBeamSlice(WhichBeamSlice::This).getParticleTileData();
// remove all invalid particles from WhichBeamSlice::This (including slipped)
amrex::removeInvalidParticles(beam.getBeamSlice(WhichBeamSlice::This));

// min_z is the lower end of WhichBeamSlice::This
const amrex::Real min_z = geom.ProbLo(2) + (slice-geom.Domain().smallEnd(2))*geom.CellSize(2);

amrex::ReduceOps<amrex::ReduceOpSum, amrex::ReduceOpSum> reduce_op;
amrex::ReduceData<int, int> reduce_data(reduce_op);
using ReduceTuple = typename decltype(reduce_data)::Type;

const int num_particles = beam.getNumParticlesIncludingSlipped(WhichBeamSlice::This);

// count the number of invalid and slipped particles
reduce_op.eval(
num_particles, reduce_data,
[=] AMREX_GPU_DEVICE (const int ip) -> ReduceTuple
{
if (ptd.id(ip) < 0) {
return {1, 0};
} else if (ptd.pos(2, ip) < min_z) {
return {0, 1};
} else {
return {0, 0};
}
// put non slipped particles at the start of the slice
const int num_stay = amrex::partitionParticles(beam.getBeamSlice(WhichBeamSlice::This),
[=] AMREX_GPU_DEVICE (auto& ptd, int i) {
return ptd.pos(2, i) >= min_z;
});

ReduceTuple t = reduce_data.value();

const int num_invalid = amrex::get<0>(t);
const int num_slipped = amrex::get<1>(t);
const int num_stay = beam.getNumParticlesIncludingSlipped(WhichBeamSlice::This)
- num_invalid - num_slipped;
const int num_slipped = beam.getBeamSlice(WhichBeamSlice::This).size() - num_stay;

if (num_invalid == 0 && num_slipped == 0) {
if (num_slipped == 0) {
// nothing to do
beam.resize(WhichBeamSlice::This, num_stay, 0);
return;
Expand All @@ -64,60 +46,19 @@ shiftSlippedParticles (BeamParticleContainer& beam, const int slice, amrex::Geom

beam.resize(WhichBeamSlice::Next, next_size, num_slipped);

BeamTile tmp{};
tmp.resize(num_stay);

const auto ptd_tmp = tmp.getParticleTileData();

// copy valid non slipped particles to the tmp tile
const int num_stay2 = amrex::Scan::PrefixSum<int> (num_particles,
[=] AMREX_GPU_DEVICE (const int ip) -> int
{
return ptd.id(ip) >= 0 && ptd.pos(2, ip) >= min_z;
},
[=] AMREX_GPU_DEVICE (const int ip, const int s)
{
if (ptd.id(ip) >= 0 && ptd.pos(2, ip) >= min_z) {
ptd_tmp.idcpu(s) = ptd.idcpu(ip);
for (int j=0; j<ptd_tmp.NAR; ++j) {
ptd_tmp.rdata(j)[s] = ptd.rdata(j)[ip];
}
for (int j=0; j<ptd_tmp.NAI; ++j) {
ptd_tmp.idata(j)[s] = ptd.idata(j)[ip];
}
}
},
amrex::Scan::Type::exclusive);

AMREX_ALWAYS_ASSERT(num_stay == num_stay2);

const auto ptd_this = beam.getBeamSlice(WhichBeamSlice::This).getParticleTileData();
const auto ptd_next = beam.getBeamSlice(WhichBeamSlice::Next).getParticleTileData();

// copy valid slipped particles to WhichBeamSlice::Next
const int num_slipped2 = amrex::Scan::PrefixSum<int> (num_particles,
[=] AMREX_GPU_DEVICE (const int ip) -> int
amrex::ParallelFor(num_slipped,
[=] AMREX_GPU_DEVICE (int i)
{
return ptd.id(ip) >= 0 && ptd.pos(2, ip) < min_z;
},
[=] AMREX_GPU_DEVICE (const int ip, const int s)
{
if (ptd.id(ip) >= 0 && ptd.pos(2, ip) < min_z) {
ptd_next.idcpu(s+next_size) = ptd.idcpu(ip);
for (int j=0; j<ptd_next.NAR; ++j) {
ptd_next.rdata(j)[s+next_size] = ptd.rdata(j)[ip];
}
for (int j=0; j<ptd_next.NAI; ++j) {
ptd_next.idata(j)[s+next_size] = ptd.idata(j)[ip];
}
}
},
amrex::Scan::Type::exclusive);

AMREX_ALWAYS_ASSERT(num_slipped == num_slipped2);
// copy particles from WhichBeamSlice::This to WhichBeamSlice::Next
amrex::copyParticle(ptd_next, ptd_this, num_stay + i, next_size + i);
});

beam.getBeamSlice(WhichBeamSlice::This).swap(tmp);
beam.resize(WhichBeamSlice::This, num_stay, 0);

// stream sync before tmp is deallocated
// stream sync before WhichBeamSlice::This is resized
amrex::Gpu::streamSynchronize();

beam.resize(WhichBeamSlice::This, num_stay, 0);
}

0 comments on commit 4a257f4

Please sign in to comment.