-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathjaxlib.Dockerfile
More file actions
41 lines (29 loc) · 1.21 KB
/
jaxlib.Dockerfile
File metadata and controls
41 lines (29 loc) · 1.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Base image for the build; must be supplied with `--build-arg BASE_IMAGE=...`.
# The steps below expect it to provide /usr/local/cuda-<MAJOR>.<MINOR> toolkits
# registered under the `cuda` alternative, and a conda install in /opt/conda.
ARG BASE_IMAGE
FROM ${BASE_IMAGE} AS builder
# jaxlib release to build; selects the `jaxlib-v<PACKAGE_VERSION>` git tag.
ARG PACKAGE_VERSION
# CUDA toolkit to build against, i.e. /usr/local/cuda-<MAJOR>.<MINOR>.
ARG CUDA_MAJOR_VERSION
ARG CUDA_MINOR_VERSION
# Make sure we are on the right version of CUDA by pointing the `cuda`
# alternative at the requested toolkit directory.
RUN update-alternatives --set cuda "/usr/local/cuda-${CUDA_MAJOR_VERSION}.${CUDA_MINOR_VERSION}"
# Ensures shared libraries installed with conda can be found by the linker
# (LIBRARY_PATH) and by the dynamic link loader (LD_LIBRARY_PATH).
# `${VAR:+${VAR}:}` only emits the existing value plus a ':' separator when VAR
# is already non-empty. Plain "$VAR:/opt/conda/lib" on an unset VAR yields a
# leading empty entry, which both search paths interpret as the current directory.
ENV LIBRARY_PATH="${LIBRARY_PATH:+${LIBRARY_PATH}:}/opt/conda/lib" \
    LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/opt/conda/lib"
# Build jaxlib from source.
# Instructions: https://jax.readthedocs.io/en/latest/developer.html#building-jaxlib-from-source
#
# This stage already has root privileges (apt-get below runs without sudo), so
# sudo is unnecessary here and may not even be installed in the base image.
# Expose python3 as `python` for the build command at the end of this stage.
# NOTE(review): if /opt/conda/bin precedes /usr/bin on PATH, `python` resolves
# to the conda interpreter rather than this symlink; confirm which is intended.
RUN ln -s /usr/bin/python3 /usr/bin/python
# DEBIAN_FRONTEND is set inline (not via ENV) so it does not leak into the
# image environment; it stops package configuration from waiting on prompts.
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y \
        g++ \
        python3 \
        python3-dev && \
    rm -rf /var/lib/apt/lists/*
# Python-side build requirements. They are unpinned; pin them if the build
# must be reproducible.
RUN pip install --no-cache-dir numpy wheel build
# Fetch only the requested release tag instead of the full repository history.
WORKDIR /usr/local/src
RUN git clone --depth 1 --branch "jaxlib-v${PACKAGE_VERSION}" https://github.com/google/jax jax
WORKDIR /usr/local/src/jax
RUN python build/build.py --enable_cuda
# Using multi-stage builds to ensure the output image is very small
# See: https://docs.docker.com/develop/develop-images/multistage-build/
# The final image only carries the built wheel(s). Pin the Alpine tag so the
# artifact image is reproducible instead of silently tracking `latest`.
FROM alpine:3.20
# A destination ending in "/" is treated as a directory and created if it does
# not exist. The trailing slash is also required when the wildcard matches more
# than one wheel.
COPY --from=builder /usr/local/src/jax/dist/*.whl /tmp/whl/
# Print out the built .whl file(s).
RUN ls -lh /tmp/whl/