snowpark
Collection
JAX ports of popular models (audio, vision, LLM, diffusion, etc.) • 1 item • Updated
I implemented DINOv2 in JAX for TPU inference and converted the pretrained weights to JAX. Model code here. The only thing here is the pickle file containing the converted weights (a pytree/state dict).
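As a rough sketch of how a pickled weight tree like this could be loaded and sanity-checked, the snippet below unpickles a file and lists the shape of every leaf. It assumes the pytree is a nested dict of arrays, which may not match the actual layout; `load_params` and `leaf_shapes` are hypothetical helper names, not part of the model code.

```python
import pickle
from typing import Any, Dict, Tuple


def load_params(path: str) -> Any:
    """Unpickle a parameter tree. Only unpickle files from a source you trust."""
    with open(path, "rb") as f:
        return pickle.load(f)


def leaf_shapes(tree: Any, prefix: str = "") -> Dict[str, Tuple[int, ...]]:
    """Flatten nested dicts into {"a/b/c": shape} for a quick sanity check.

    Assumption: the tree is made of nested dicts whose leaves are array-like.
    """
    if isinstance(tree, dict):
        out: Dict[str, Tuple[int, ...]] = {}
        for key, value in tree.items():
            name = f"{prefix}/{key}" if prefix else str(key)
            out.update(leaf_shapes(value, name))
        return out
    # Leaves without a .shape attribute (e.g. Python scalars) report ().
    return {prefix: tuple(getattr(tree, "shape", ()))}
```

Usage would look like `params = load_params("dinov2_small.pkl")` followed by `leaf_shapes(params)`, where the file name is a placeholder for whatever the pickle is actually called. If the file stores JAX arrays, the `jax` package has to be importable when unpickling.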
Base model
facebook/dinov2-small