Merge pull request #113207 from danieldk/pytorch-cuda-11

python3Packages.pytorch: add compute capabilities for CUDA 11
wip/yesman
Daniël de Kok 3 years ago committed by GitHub
commit e9b3e36f44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 44
      pkgs/development/python-modules/pytorch/default.nix

@ -74,27 +74,35 @@ let
# (allowing FBGEMM to be built in pytorch-1.1), and may future proof this
# derivation.
brokenArchs = [ "3.0" ]; # this variable is only used as documentation.
cuda9ArchList = [
"3.5"
"5.0"
"5.2"
"6.0"
"6.1"
"7.0"
"7.0+PTX" # I am getting a "undefined architecture compute_75" on cuda 9
# which leads me to believe this is the final cuda-9-compatible architecture.
];
cuda10ArchList = cuda9ArchList ++ [
"7.5"
"7.5+PTX" # < most recent architecture as of cudatoolkit_10_0 and pytorch-1.2.0
];
cudaCapabilities = rec {
cuda9 = [
"3.5"
"5.0"
"5.2"
"6.0"
"6.1"
"7.0"
"7.0+PTX" # I am getting a "undefined architecture compute_75" on cuda 9
# which leads me to believe this is the final cuda-9-compatible architecture.
];
cuda10 = cuda9 ++ [
"7.5"
"7.5+PTX" # < most recent architecture as of cudatoolkit_10_0 and pytorch-1.2.0
];
cuda11 = cuda10 ++ [
"8.0"
"8.0+PTX" # < CUDA toolkit 11.0
"8.6"
"8.6+PTX" # < CUDA toolkit 11.1
];
};
final_cudaArchList =
if !cudaSupport || cudaArchList != null
then cudaArchList
else
if lib.versions.major cudatoolkit.version == "9"
then cuda9ArchList
else cuda10ArchList; # the assert above removes any ambiguity here.
else cudaCapabilities."cuda${lib.versions.major cudatoolkit.version}";
# Normally libcuda.so.1 is provided at runtime by nvidia-x11 via
# LD_LIBRARY_PATH=/run/opengl-driver/lib. We only use the stub

Loading…
Cancel
Save