Skip to content
Snippets Groups Projects
Commit d1e555d3 authored by Daniel Vonk's avatar Daniel Vonk Committed by Daniel Vonk
Browse files

Use cufft instead of fftw in CUDAScatter

parent a38ae1d9
No related branches found
No related tags found
2 merge requests: !19 "Merge develop into main", !17 "Add CUDA backend"
......@@ -47,6 +47,11 @@ SelfVectorsCUDAScatterDevice::SelfVectorsCUDAScatterDevice(
fftw_complex *wspace = nullptr;
fftw_planF_ = fftw_plan_dft_1d(2 * NF, wspace, wspace, FFTW_FORWARD, FFTW_ESTIMATE);
fftw_planB_ = fftw_plan_dft_1d(2 * NF, wspace, wspace, FFTW_BACKWARD, FFTW_ESTIMATE);
auto res = cufftPlan1d(&plan_fwd_, 2 * NF, CUFFT_Z2Z, 1);
if (res != CUFFT_SUCCESS) {
sass::err("Could not create a DFT plan for autocorrelation.");
}
}
void SelfVectorsCUDAScatterDevice::stage_data()
......@@ -460,18 +465,48 @@ void SelfVectorsCUDAScatterDevice::tf_scatter(size_t atomindex, size_t index, si
2 * N, NF, dat_, dscatter_factors, index, atomindex)
.name("self_scatter");
auto dsp = cudaflow.kernel(1, NTHREADS, 0, sass::cuda::square_elements, dat_, 2 * N, NF)
.name("square_dsp");
auto store =
cudaflow
.kernel(1, NTHREADS, 0, sass::cuda::store, dafinal_, da2final_, datfinal_, dat_, N, NF)
.name("store");
kernel.succeed(zero_dat, h2d_scatter_factors, h2d_coords, h2d_subvector_index).precede(dsp);
dsp.precede(d2h_dat_, store);
kernel.succeed(zero_dat, h2d_scatter_factors, h2d_coords, h2d_subvector_index);
store.precede(d2h_dat_);
if (Params::Inst()->scattering.dsp.type == "square") {
auto dsp = cudaflow.kernel(1, NTHREADS, 0, sass::cuda::square_elements, dat_, 2 * N, NF)
.name("square_dsp");
dsp.precede(store).succeed(kernel);
} else if (Params::Inst()->scattering.dsp.type == "autocorrelate") {
auto dsp = cudaflow
.capture([&](tf::cudaFlowCapturer &cpt) {
cpt.on([&](cudaStream_t stream) {
cufftSetStream(plan_fwd_, stream);
auto res = cufftExecZ2Z(plan_fwd_, ( cufftDoubleComplex * )dat_,
( cufftDoubleComplex * )dat_, CUFFT_FORWARD);
if (res != CUFFT_SUCCESS)
sass::err("Error execuring DFT fwd for autocorrelation.");
sass::cuda::autocorrelate_square<<<1, NTHREADS, 0, stream>>>(
dat_, 2 * N, NF);
res = cufftExecZ2Z(plan_fwd_, ( cufftDoubleComplex * )dat_,
( cufftDoubleComplex * )dat_, CUFFT_INVERSE);
if (res != CUFFT_SUCCESS)
sass::err("Error execuring DFT fwd for autocorrelation.");
sass::cuda::autocorrelate_conj_factor<<<1, NTHREADS, 0, stream>>>(
dat_, 2 * N, NF);
});
})
.name("autocorrelate_dsp");
dsp.precede(store).succeed(kernel);
}
// cudaflow.dump(std::cout);
tf::cudaStream stream;
cudaflow.run(stream);
......
......@@ -96,4 +96,7 @@ protected:
~SelfVectorsCUDAScatterDevice();
fftw_plan fftw_planF_;
fftw_plan fftw_planB_;
cufftHandle plan_fwd_;
cufftHandle plan_bkw_;
};
......@@ -94,6 +94,36 @@ __global__ void sass::cuda::square_elements(fftw_complex *at, size_t N, size_t N
}
}
}
// Squared-magnitude kernel used by the "autocorrelate" DSP path (launched
// between the forward and the inverse cuFFT transform).
//
// Each thread with threadIdx.x < N takes the run of 2 * NF complex elements
// beginning at at[threadIdx.x * NF] and overwrites every element z with
// (re(z)^2 + im(z)^2, 0).
//
// Only threadIdx.x is used for indexing, so subvectors with an index at or
// beyond blockDim.x are not touched and extra blocks would repeat the same work.
//
// NOTE(review): the start offsets advance by NF while each thread walks
// 2 * NF elements, so the ranges of neighbouring threads overlap by NF
// elements. Confirm this matches the intended (zero-padded?) layout of `at`.
__global__ void sass::cuda::autocorrelate_square(fftw_complex *at, size_t N, size_t NF)
{
    // Threads that do not map onto a subvector have nothing to do.
    if (threadIdx.x >= N) {
        return;
    }

    // First element handled by this thread.
    fftw_complex *subvector = at + threadIdx.x * NF;
    const size_t count = 2 * NF;

    for (size_t k = 0; k != count; ++k) {
        const double re = subvector[k][0];
        const double im = subvector[k][1];
        subvector[k][0] = re * re + im * im;
        subvector[k][1] = 0;
    }
}
// Scaling kernel used by the "autocorrelate" DSP path (launched after the
// inverse cuFFT transform).
//
// Each thread with threadIdx.x < N takes the NF complex elements beginning at
// at[threadIdx.x * NF] and multiplies element k (real and imaginary part) by
// 1 / (2 * NF * (NF - k)).
//
// Only threadIdx.x is used for indexing, so subvectors with an index at or
// beyond blockDim.x are not touched and extra blocks would repeat the same work.
//
// NOTE(review): presumably 1 / (2 * NF) undoes the unnormalised length-2*NF
// transform and 1 / (NF - k) averages over the contributing samples -- confirm.
__global__ void sass::cuda::autocorrelate_conj_factor(fftw_complex *at, size_t N, size_t NF)
{
    // Threads that do not map onto a subvector have nothing to do.
    if (threadIdx.x >= N) {
        return;
    }

    // First element handled by this thread.
    fftw_complex *subvector = at + threadIdx.x * NF;

    for (size_t k = 0; k != NF; ++k) {
        // NF - k stays >= 1 inside the loop, so the divisor is never zero here.
        const double scale = (1.0 / (2.0 * NF * (NF - k)));
        subvector[k][0] *= scale;
        subvector[k][1] *= scale;
    }
}
__global__ void sass::cuda::store(fftw_complex *afinal, fftw_complex *a2final,
fftw_complex *atfinal, fftw_complex *at, size_t N, size_t NF)
......
......@@ -16,5 +16,7 @@ __global__ void cuda_scatter(coor_t *p_coords, CartesianCoor3D *subvector_index,
__global__ void square_elements(fftw_complex *at, size_t N, size_t NF);
__global__ void store(fftw_complex *afinal, fftw_complex *a2final, fftw_complex *atfinal,
fftw_complex *at, size_t N, size_t NF);
// Overwrites each complex element of a thread's run (2 * NF elements starting
// at at[threadIdx.x * NF], for threadIdx.x < N) with its squared magnitude and
// a zero imaginary part.
__global__ void autocorrelate_square(fftw_complex *at, size_t N, size_t NF);
// Multiplies element i of a thread's run (NF elements starting at
// at[threadIdx.x * NF], for threadIdx.x < N) by 1 / (2 * NF * (NF - i)).
__global__ void autocorrelate_conj_factor(fftw_complex *at, size_t N, size_t NF);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment