import os
import libtbx.load_env
Import("env_etc env_simtbx nanoBragg_obj nanoBragg_env")

# diffBragg compiles with its own copy of the nanoBragg construction
# environment, so the settings added in this script are applied to the copy.
diffBragg_env = nanoBragg_env.Clone()
if env_etc.enable_kokkos:
  # Preprocessor switch for the Kokkos code paths in the diffBragg sources.
  kokkos_defines = ["DIFFBRAGG_HAVE_KOKKOS"]
  diffBragg_env.Prepend(CPPDEFINES=kokkos_defines)


# Locate the Eigen headers.  First choice is an "eigen" directory two levels
# above the simtbx distribution; if that is absent, fall back to an "eigen3"
# directory inside one of the conda include paths (when env_etc has them).
bundled_eigen = os.path.join(libtbx.env.dist_path("simtbx"), "../../eigen")
env_etc.eigen_dist = os.path.abspath(bundled_eigen)
eigen_missing = not os.path.isdir(env_etc.eigen_dist)
if eigen_missing and hasattr(env_etc, "conda_cpppath"):
  for include_root in env_etc.conda_cpppath:
    conda_eigen = os.path.join(include_root, "eigen3")
    if os.path.isdir(conda_eigen):
      # No break: if several include roots contain eigen3, the last one wins.
      env_etc.eigen_dist = os.path.abspath(conda_eigen)
# NOTE(review): eigen_include is only assigned here when a directory was
# found; otherwise the registry call below relies on it having been set
# elsewhere (presumably by another SConscript) - TODO confirm.
if os.path.isdir(env_etc.eigen_dist):
  env_etc.eigen_include = env_etc.eigen_dist

# Register the Eigen include path for the diffBragg environment.
eigen_include_paths = [env_etc.eigen_include]
env_etc.include_registry.append(env=diffBragg_env, paths=eigen_include_paths)

# Extra nvcc flags are added unless DIFFBRAGG_SHOW_ALL_NVCC_WARNINGS is
# present in the environment (any value, including an empty string, counts).
# Note that --expt-relaxed-constexpr is part of the same flag string, so it is
# also omitted when the variable is set.
show_all_warnings = "DIFFBRAGG_SHOW_ALL_NVCC_WARNINGS" in os.environ
if not show_all_warnings:
  warn_suppress_flags = "-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored --expt-relaxed-constexpr"
  nvcc_extra_flags = warn_suppress_flags.split()
  diffBragg_env.Prepend(NVCCFLAGS=nvcc_extra_flags)

# CUDA backend: build the diffBragg CUDA kernels into their own shared
# library, define DIFFBRAGG_HAVE_CUDA for both the host compiler and nvcc,
# and add the new library to the simtbx link list.
if env_etc.enable_cuda:
  cuda_sources = ["src/diffBraggCUDA.cu", "src/diffBragg_gpu_kernel.cu"]
  diffBragg_env.cudaSharedLibrary(
    target="#lib/libsimtbx_diffBraggCUDA",
    source=cuda_sources,
  )
  diffBragg_env.Prepend(CPPDEFINES=["DIFFBRAGG_HAVE_CUDA"])
  diffBragg_env.Prepend(NVCCFLAGS=["-DDIFFBRAGG_HAVE_CUDA"])
  env_simtbx.Prepend(LIBS=["simtbx_diffBraggCUDA"])

# Host-side sources of the extension, compiled to shared objects with the
# diffBragg environment; the objects are linked into the extension module at
# the end of this script.
diffBragg_sources = [
  "src/diffBragg_ext.cpp",
  "src/diffBragg.cpp",
  "src/diffBragg_cpu_kernel.cpp",
]
diffBragg_obj = diffBragg_env.SharedObject(source=diffBragg_sources)

# Extra link libraries / library search paths for the final extension link;
# extended by the Kokkos section when a GPU backend is selected.
libs, libpath = [], []

# Kokkos backend.  Reads the Kokkos build settings from the process
# environment, points diffBragg_env at the matching device compiler, builds
# libsimtbx_diffBraggKOKKOS from a cloned environment, and records the GPU
# runtime libraries needed by the final extension link.
#
# Environment variables used:
#   KOKKOS_CXXFLAGS  required (KeyError if missing); whitespace-separated flags
#   KOKKOS_PATH      optional; defaults to ../../kokkos relative to simtbx
#   KOKKOS_DEVICES   optional; a value containing 'Cuda' or 'HIP' selects that
#                    backend, anything else (or unset) selects no GPU backend
#   CUDA_HOME        required when the Cuda backend is selected
#   ROCM_PATH        required when the HIP backend is selected
if env_etc.enable_kokkos:
  print('Getting kokkos build variables')
  print('-'*79)
  kokkos_cxxflags = os.environ['KOKKOS_CXXFLAGS'].split()
  # The variable may hold a literal `echo "flag" "flag" ...` command line;
  # in that case drop the leading `echo` and the quotes around each flag.
  if kokkos_cxxflags and kokkos_cxxflags[0] == 'echo':
    kokkos_cxxflags = [flag.strip('"') for flag in kokkos_cxxflags[1:]]
  print('KOKKOS_CXXFLAGS:', kokkos_cxxflags)
  print('='*79)

  # Example of a parsed flag list on a CUDA system:
  #   ['-std=c++14',
  #    '-Xcudafe', '--diag_suppress=esa_on_defaulted_function_ignored',
  #    '-expt-extended-lambda', '-arch=sm_80', '-I./',
  #    '-I<kokkos>/core/src', '-I<kokkos>/containers/src',
  #    '-I<kokkos>/algorithms/src']

  if os.getenv('KOKKOS_PATH') is None:
    os.environ['KOKKOS_PATH'] = libtbx.env.under_dist('simtbx', '../../kokkos')
  kokkos_path = os.environ['KOKKOS_PATH']

  # Decide the device backend once.  An unset KOKKOS_DEVICES is treated as
  # "no GPU backend" instead of failing with a TypeError on `in None`.
  kokkos_devices = os.environ.get('KOKKOS_DEVICES', '')
  kokkos_uses_cuda = 'Cuda' in kokkos_devices
  kokkos_uses_hip = (not kokkos_uses_cuda) and 'HIP' in kokkos_devices

  if kokkos_uses_cuda:
    # Compile through the nvcc_wrapper script found under KOKKOS_PATH/bin.
    nvcc_wrapper = os.path.join(kokkos_path, 'bin', 'nvcc_wrapper')
    diffBragg_env.Replace(CXX=nvcc_wrapper)
    diffBragg_env.Replace(SHCXX=nvcc_wrapper)
  elif kokkos_uses_hip:
    diffBragg_env.Replace(CXX='hipcc')
    diffBragg_env.Replace(SHCXX='hipcc')

  # The same flags go into both CXXFLAGS and CPPFLAGS; separate list objects
  # are passed so the two construction variables never share one list.
  diffBragg_env.Prepend(CXXFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
  diffBragg_env.Prepend(CPPFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
  diffBragg_env.Prepend(CPPPATH=[kokkos_path])
  diffBragg_env.Append(LIBS=['kokkoscontainers','kokkoscore'])

  diffBraggKokkos_env = diffBragg_env.Clone()

  # The Kokkos library is built without the NANOBRAGG_HAVE_CUDA,
  # DIFFBRAGG_HAVE_CUDA and DIFFBRAGG_HAVE_KOKKOS macros: remove them from the
  # cloned environment's CPPDEFINES (first occurrence of each) ...
  if 'CPPDEFINES' in diffBraggKokkos_env:
    cppdefines = diffBraggKokkos_env['CPPDEFINES']
    for macro in ['NANOBRAGG_HAVE_CUDA', 'DIFFBRAGG_HAVE_CUDA', 'DIFFBRAGG_HAVE_KOKKOS']:
      if macro in cppdefines:
        cppdefines.remove(macro)

  # ... and drop the matching -D option from its NVCCFLAGS.
  if 'NVCCFLAGS' in diffBraggKokkos_env:
    nvccflags = diffBraggKokkos_env['NVCCFLAGS']
    cuda_define_flag = '-DDIFFBRAGG_HAVE_CUDA'
    if cuda_define_flag in nvccflags:
      nvccflags.remove(cuda_define_flag)

  diffBraggKokkos_lib = diffBraggKokkos_env.SharedLibrary(
    target="#lib/libsimtbx_diffBraggKOKKOS",
    source=["src/diffBraggKOKKOS.cpp", "src/diffBragg_kokkos_kernel.cpp", "../../kokkostbx/kokkos_utils.cpp"],
    LIBS=['kokkoscontainers', 'kokkoscore'])

  print("LIBS: ", env_simtbx['LIBS'])

  # GPU runtime libraries and search paths for the final extension link.
  if kokkos_uses_cuda:
    cuda_home = os.environ['CUDA_HOME']
    libpath += [os.path.join(cuda_home, 'lib64')]
    libpath += [os.path.join(cuda_home, 'lib64/stubs')]
    libpath += [os.path.join(cuda_home, 'compat')]
    libs += ["cudart", "cuda"]
  elif kokkos_uses_hip:
    libpath += [os.path.join(os.environ['ROCM_PATH'], 'lib')]
    libs += ["amdhip64", "hsa-runtime64"]

# Link the Python extension module from the diffBragg objects.
# Multiplying the one-element list by the enable_kokkos flag yields the Kokkos
# library name when the flag is true and an empty list when it is false.
kokkos_link_libs = ['simtbx_diffBraggKOKKOS'] * env_etc.enable_kokkos
ext_libpath = env_simtbx['LIBPATH'] + libpath
ext_libs = kokkos_link_libs + env_simtbx['LIBS'] + libs
env_simtbx.SharedLibrary(
  target="#lib/simtbx_diffBragg_ext",
  source=[diffBragg_obj],
  LIBPATH=ext_libpath,
  LIBS=ext_libs,
)
