Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/shammodels/bindings/pyShamrockCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ void append_to_map(

{
auto acc = field.get_buf().copy_to_stdvec();
u32 len = field.size();
u32 len = field.get_val_cnt();

for (u32 i = 0; i < len; i++) {
vec.push_back(acc[i]);
Expand Down
24 changes: 20 additions & 4 deletions src/shammodels/sph/Model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,18 @@ namespace shammodels::sph {
PatchDataField<T> &f
= pdat.template get_field<T>(sched.pdl.get_field_idx<T>(field_name));

if (f.get_nvar() != 1) {
shambase::throw_unimplemented();
}

{
auto &buf = f.get_buf();
auto acc = buf.copy_to_stdvec();

auto &buf_xyz = xyz.get_buf();
auto acc_xyz = buf_xyz.copy_to_stdvec();

for (u32 i = 0; i < f.size(); i++) {
for (u32 i = 0; i < f.get_obj_cnt(); i++) {
Tvec r = acc_xyz[i];

acc[i] = pos_to_val(r);
Expand Down Expand Up @@ -590,11 +594,15 @@ namespace shammodels::sph {
PatchDataField<T> &f
= pdat.template get_field<T>(sched.pdl.get_field_idx<T>(field_name));

if (f.get_nvar() != 1) {
shambase::throw_unimplemented();
}

{
auto acc = f.get_buf().template mirror_to<sham::host>();
auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();

for (u32 i = 0; i < f.size(); i++) {
for (u32 i = 0; i < f.get_obj_cnt(); i++) {
Tvec r = acc_xyz[i];

if (BBAA::is_coord_in_range(r, std::get<0>(box), std::get<1>(box))) {
Expand All @@ -617,12 +625,16 @@ namespace shammodels::sph {
PatchDataField<T> &f
= pdat.template get_field<T>(sched.pdl.get_field_idx<T>(field_name));

if (f.get_nvar() != 1) {
shambase::throw_unimplemented();
}

Tscal r2 = radius * radius;
{
auto acc = f.get_buf().template mirror_to<sham::host>();
auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();

for (u32 i = 0; i < f.size(); i++) {
for (u32 i = 0; i < f.get_obj_cnt(); i++) {
Tvec dr = acc_xyz[i] - center;

if (sycl::dot(dr, dr) < r2) {
Expand All @@ -645,11 +657,15 @@ namespace shammodels::sph {
PatchDataField<T> &f
= pdat.template get_field<T>(sched.pdl.get_field_idx<T>(field_name));

if (f.get_nvar() != 1) {
shambase::throw_unimplemented();
}

{
auto acc = f.get_buf().template mirror_to<sham::host>();
auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();

for (u32 i = 0; i < f.size(); i++) {
for (u32 i = 0; i < f.get_obj_cnt(); i++) {
Tvec dr = acc_xyz[i] - center;

Tscal r = sycl::length(dr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,22 @@ namespace patchdata_field {
i32 tag,
MPI_Comm comm) {

rq_lst.emplace_back(p, current_mode, Send, p.size());
rq_lst.emplace_back(p, current_mode, Send, p.get_val_cnt());

u32 rq_index = rq_lst.size() - 1;

auto &rq = rq_lst[rq_index];

mpi::isend(
rq.get_mpi_ptr(),
p.size(),
p.get_val_cnt(),
get_mpi_type<T>(),
rank_dest,
tag,
comm,
&(rq_lst[rq_index].mpi_rq));

return sizeof(T) * p.size();
return sizeof(T) * p.get_val_cnt();
}

template<class T>
Expand Down Expand Up @@ -145,9 +145,9 @@ namespace patchdata_field {
inline void file_write(MPI_File fh, PatchDataField<T> &p) {
MPI_Status st;

PatchDataFieldMpiRequest<T> rq(p, current_mode, Send, p.size());
PatchDataFieldMpiRequest<T> rq(p, current_mode, Send, p.get_val_cnt());

mpi::file_write(fh, rq.get_mpi_ptr(), p.size(), get_mpi_type<T>(), &st);
mpi::file_write(fh, rq.get_mpi_ptr(), p.get_val_cnt(), get_mpi_type<T>(), &st);

rq.finalize();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,18 @@ get_new_id_map<f32_3>(PatchScheduler &sched, SerialPatchTree<f32_3> &sptree) {
u32 ixyz = sched.pdl.get_field_idx<f32_3>("xyz");
PatchDataField<f32_3> &xyz_field = pdat.get_field<f32_3>(ixyz);

if (xyz_field.get_nvar() != 1) {
shambase::throw_unimplemented();
}

auto &pos = xyz_field.get_buf();

newid_buf_map.insert(
{id,
sptree.compute_patch_owner(
shamsys::instance::get_compute_scheduler_ptr(), pos, xyz_field.size())});
shamsys::instance::get_compute_scheduler_ptr(),
pos,
xyz_field.get_obj_cnt())});
}
});

Expand All @@ -69,12 +75,18 @@ get_new_id_map<f64_3>(PatchScheduler &sched, SerialPatchTree<f64_3> &sptree) {
u32 ixyz = sched.pdl.get_field_idx<f64_3>("xyz");
PatchDataField<f64_3> &xyz_field = pdat.get_field<f64_3>(ixyz);

if (xyz_field.get_nvar() != 1) {
shambase::throw_unimplemented();
}

auto &pos = xyz_field.get_buf();

newid_buf_map.insert(
{id,
sptree.compute_patch_owner(
shamsys::instance::get_compute_scheduler_ptr(), pos, xyz_field.size())});
shamsys::instance::get_compute_scheduler_ptr(),
pos,
xyz_field.get_obj_cnt())});
}
});

Expand All @@ -101,12 +113,18 @@ reatribute_particles<f32_3>(PatchScheduler &sched, SerialPatchTree<f32_3> &sptre
u32 ixyz = sched.pdl.get_field_idx<f32_3>("xyz");
PatchDataField<f32_3> &xyz_field = pdat.get_field<f32_3>(ixyz);

if (xyz_field.get_nvar() != 1) {
shambase::throw_unimplemented();
}

auto &pos = xyz_field.get_buf();

newid_buf_map.insert(
{id,
sptree.compute_patch_owner(
shamsys::instance::get_compute_scheduler_ptr(), pos, xyz_field.size())});
shamsys::instance::get_compute_scheduler_ptr(),
pos,
xyz_field.get_obj_cnt())});

{
// auto nid = newid_buf_map.at(id).get_access<sycl::access::mode::read>();
Expand Down Expand Up @@ -187,12 +205,18 @@ reatribute_particles<f32_3>(PatchScheduler &sched, SerialPatchTree<f32_3> &sptre
u32 ixyz = sched.pdl.get_field_idx<f32_3>("xyz");
PatchDataField<f32_3> &xyz_field = pdat.get_field<f32_3>(ixyz);

if (xyz_field.get_nvar() != 1) {
shambase::throw_unimplemented();
}

auto &pos = xyz_field.get_buf();

newid_buf_map.insert(
{id,
sptree.compute_patch_owner(
shamsys::instance::get_compute_scheduler_ptr(), pos, xyz_field.size())});
shamsys::instance::get_compute_scheduler_ptr(),
pos,
xyz_field.get_obj_cnt())});
}
});
}
Expand Down Expand Up @@ -327,12 +351,18 @@ reatribute_particles<f64_3>(PatchScheduler &sched, SerialPatchTree<f64_3> &sptre
u32 ixyz = sched.pdl.get_field_idx<f64_3>("xyz");
PatchDataField<f64_3> &xyz_field = pdat.get_field<f64_3>(ixyz);

if (xyz_field.get_nvar() != 1) {
shambase::throw_unimplemented();
}

auto &pos = xyz_field.get_buf();

newid_buf_map.insert(
{id,
sptree.compute_patch_owner(
shamsys::instance::get_compute_scheduler_ptr(), pos, xyz_field.size())});
shamsys::instance::get_compute_scheduler_ptr(),
pos,
xyz_field.get_obj_cnt())});

{
// auto nid = newid_buf_map.at(id).get_access<sycl::access::mode::read>();
Expand Down Expand Up @@ -414,12 +444,18 @@ reatribute_particles<f64_3>(PatchScheduler &sched, SerialPatchTree<f64_3> &sptre
u32 ixyz = sched.pdl.get_field_idx<f64_3>("xyz");
PatchDataField<f64_3> &xyz_field = pdat.get_field<f64_3>(ixyz);

if (xyz_field.get_nvar() != 1) {
shambase::throw_unimplemented();
}

auto &pos = xyz_field.get_buf();

newid_buf_map.insert(
{id,
sptree.compute_patch_owner(
shamsys::instance::get_compute_scheduler_ptr(), pos, xyz_field.size())});
shamsys::instance::get_compute_scheduler_ptr(),
pos,
xyz_field.get_obj_cnt())});
}
});
}
Expand Down
5 changes: 4 additions & 1 deletion src/shamrock/include/shamrock/patch/PatchData.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,9 @@ namespace shamrock::patch {

/**
* @brief Fetch data of a patchdata field into a std::vector
*
* @todo Improve for nvar != 1
*
* @tparam T
* @param key
* @param pdat
Expand All @@ -417,7 +420,7 @@ namespace shamrock::patch {

if (!field.is_empty()) {
auto acc = field.get_buf().copy_to_stdvec();
u32 len = field.size();
u32 len = field.get_val_cnt();

for (u32 i = 0; i < len; i++) {
vec.push_back(acc[i]);
Expand Down
27 changes: 22 additions & 5 deletions src/shamrock/include/shamrock/patch/PatchDataField.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,25 @@ class PatchDataField {

inline sham::DeviceBuffer<T> &get_buf() { return buf; }

[[nodiscard]] inline u32 size() const { return buf.get_size(); }

[[nodiscard]] inline bool is_empty() const { return size() == 0; }
[[nodiscard]] inline bool is_empty() const { return get_obj_cnt() == 0; }

[[nodiscard]] inline u64 memsize() const { return buf.get_mem_usage(); }

[[nodiscard]] inline const u32 &get_nvar() const { return nvar; }

[[nodiscard]] inline const u32 &get_obj_cnt() const { return obj_cnt; }

/**
* @brief Get the number of values stored in the field.
*
* This function was introduced to replace the legacy one size() which could be confused with
* the of the buffer, which is not required to be the same.
*
* @return u32 the total number of values of the field, which is the product of the number of
* objects and the number of variables per object.
*/
[[nodiscard]] inline u32 get_val_cnt() const { return get_obj_cnt() * get_nvar(); }

[[nodiscard]] inline const std::string &get_name() const { return field_name; }

// TODO add overflow check
Expand Down Expand Up @@ -523,10 +532,14 @@ PatchDataField<T>::get_elements_with_range(Lambdacd &&cd_true, T vmin, T vmax) {
}
*/

if (nvar != 1) {
shambase::throw_unimplemented();
}

{
auto acc = buf.copy_to_stdvec();

for (u32 i = 0; i < size(); i++) {
for (u32 i = 0; i < get_val_cnt(); i++) {
if (cd_true(acc[i], vmin, vmax)) {
idxs.push_back(i);
}
Expand Down Expand Up @@ -572,12 +585,16 @@ PatchDataField<T>::check_err_range(Lambdacd &&cd_true, T vmin, T vmax, std::stri
return;
}

if (nvar != 1) {
shambase::throw_unimplemented();
}

bool error = false;
{
auto acc = buf.copy_to_stdvec();
u32 err_cnt = 0;

for (u32 i = 0; i < size(); i++) {
for (u32 i = 0; i < get_val_cnt(); i++) {
if (!cd_true(acc[i], vmin, vmax)) {
logger::err_ln(
"PatchDataField",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,16 @@ namespace shamrock {

PatchDataField<T> &pos_field = pdat.get_field<T>(ipos);

if (pos_field.get_nvar() != 1) {
shambase::throw_unimplemented();
}

newid_buf_map.add_obj(
id,
sptree.compute_patch_owner(
shamsys::instance::get_compute_scheduler_ptr(),
pos_field.get_buf(),
pos_field.size()));
pos_field.get_obj_cnt()));

bool err_id_in_newid = false;
{
Expand Down
2 changes: 1 addition & 1 deletion src/shamrock/src/legacy/patch/utility/merged_patch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ auto MergedPatchCompField<flt, T>::merge_patches_cfield(

auto &merged_field = merged_data.at(id_patch);

merged_field.or_element_cnt = compfield.size();
merged_field.or_element_cnt = compfield.get_val_cnt();
merged_field.buf.insert(compfield);

std::vector<std::tuple<u64, std::unique_ptr<PatchDataField<T>>>> &p_interf_lst
Expand Down
Loading