Title: TinyLlama: An Open-Source Small Language Model
Authors: Peiyuan Zhang, Guangtao Zeng, Tianduo Wang, Wei Lu
Published: 4th January 2024 (Thursday) @ 17:54:59
Link: http://arxiv.org/abs/2401.02385v2

Abstract

We present TinyLlama, a compact 1.1B language model pretrained on around 1 trillion tokens for approximately 3 epochs. Building on the architecture and tokenizer of Llama 2, TinyLlama leverages various advances contributed by the open-source community (e.g., FlashAttention and Lit-GPT), achieving better computational efficiency. Despite its relatively small size, TinyLlama demonstrates remarkable performance in a series of downstream tasks. It significantly outperforms existing open-source language models with comparable sizes. Our model checkpoints and code are publicly available on GitHub at https://github.com/jzhang38/TinyLlama.


  • 1.1B language model
  • Pretrained on 1 trillion tokens for 3 epochs - so model has seen 3 trillion tokens, i.e. 3x 1T tokens from the dataset?
  • Uses FlashAttention
  • Outperforms models of comparable sizes - on what metrics?
    • “Specifically, TinyLlama surpasses both OPT-1.3B (Zhang et al., 2022) and Pythia-1.4B (Biderman et al., 2023) in various downstream tasks.”
  • Model checkpoints and code on GitHub: https://github.com/jzhang38/TinyLlama
  • Training Compute-Optimal Large Language Models (Hoffmann et al., 2022; the Chinchilla paper) says to scale model size and training data in roughly equal proportion as compute grows (see the formula sketch after the Chinchilla’s Death notes below)
  • Chinchilla’s Death:
    • The Chinchilla paper stopped training models at the point where the larger models’ loss dipped below the smaller models’, so we don’t have a convincing picture of what the smaller models’ loss curves would look like if their training were continued
    • Llama (1) observations on loss curves:
      • Loss curves for the larger models are pretty consistently below those of the smaller models - 7, 13, 33, and 65 billion parameter models, trained from 0 to 1,400 billion (= 1.4 trillion) tokens
      1. Each curve first plummets in a power law,
      2. and then seemingly enters a nearly-linear decrease in loss (corresponding to a fairly constant rate of knowledge acquisition).
      3. At the very tip of the curve, they all break this line by flattening slightly.
    • The blog post author says the flattening of the loss at the end is because the cosine schedule decreases the learning rate towards zero - with more tokens, the learning-rate decay could be stretched out
    • Training the smaller models with an equal amount of compute - not for the same number of tokens - may yield better performance
    • I guess there’s a saturation in terms of information storage capacity - what papers are there on this?
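
For reference, a rough sketch of the Chinchilla relations mentioned above (the exact fitted exponents differ slightly between the paper's estimation approaches, so treat these as approximate):

```latex
C \approx 6\,N D, \qquad
N_{\mathrm{opt}} \propto C^{\,0.5}, \quad
D_{\mathrm{opt}} \propto C^{\,0.5}
\;\;\Rightarrow\;\;
D_{\mathrm{opt}} \approx 20\,N_{\mathrm{opt}}
```

By that rule of thumb, TinyLlama's 1.1B parameters would call for only ~22B training tokens, so a ~3T-token run is deliberately far past compute-optimal - exactly the regime the Chinchilla's Death post bets on: smaller models keep improving well past the compute-optimal point.
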
  • Data - Trained on:
    • SlimPajama: a de-duplicated and filtered version of RedPajama (RedPajama is 1.2 trillion tokens; SlimPajama retains ~0.6T tokens, about 50% of RedPajama) - they drop the GitHub section of SlimPajama since they use StarCoder data instead
    • StarCoder training set: code in 86 programming languages, plus GitHub issues and natural language-code text pairs
    • ~950B tokens in total for pretraining across the combined sources
    • Llama tokenizer
    • SlimPajama:StarCoder sampled at roughly a 7:3 ratio
  • Architecture:
    • RoPE like Llama 2, Qwen and PaLM
    • Pre-norm and RMSNorm
    • SwiGLU (not ReLU)
    • Grouped Query Attention
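
A minimal NumPy sketch of some of the components in the architecture list above (RMSNorm, SwiGLU, RoPE). The shapes, the eps value, and the rotate-by-halves pairing convention are assumptions for illustration, not TinyLlama's exact implementation.

```python
import numpy as np

def rms_norm(x, gain, eps=1e-5):
    """RMSNorm: rescale by the root-mean-square of the activations (no mean
    subtraction, unlike LayerNorm), then apply a learned per-dimension gain."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * gain

def swiglu(x, w_gate, w_up, w_down):
    """SwiGLU feed-forward block: SiLU(x @ W_gate) gates (x @ W_up),
    then the gated result is projected back down with W_down."""
    silu = lambda z: z / (1.0 + np.exp(-z))  # SiLU / swish activation
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

def rope(x, base=10000.0):
    """Rotary positional embeddings for x of shape (seq_len, head_dim):
    dimension pairs are rotated by angles that grow with position, so
    query-key dot products become a function of relative position."""
    seq_len, head_dim = x.shape
    half = head_dim // 2
    inv_freq = base ** (-np.arange(half) * 2.0 / head_dim)    # (half,)
    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]  # (seq_len, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]                         # pair dim i with dim i+half
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
```

Pre-norm here means rms_norm is applied to the input of each attention/MLP sub-block rather than to its output; grouped-query attention (not sketched) shares each key/value head across a group of query heads to shrink the KV cache.
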
  • Training:
    • FSDP via https://huggingface.co/docs/accelerate/usage_guides/fsdp
    • FlashAttention
    • Fused layernorm, fused cross-entropy loss, fused RoPE, fused SwiGLU (from xFormers)
    • Can use larger batch size due to fused kernels
      • Why do fused kernels allow for larger batch sizes? Presumably because fusing avoids writing large intermediate activations out to GPU memory, so the memory footprint drops; the paper notes these optimizations let the 1.1B model fit within 40GB of GPU RAM, leaving headroom for bigger batches
    • Training throughput of 24,000 tokens per second per A100-40G GPU
      • Needs 3,456 A100 GPU-hours to train on 300B tokens, vs 4,830 for Pythia-1.0B and 7,920 for MPT-1.3B (quick arithmetic check below)
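
A quick back-of-the-envelope check that the GPU-hour figure is consistent with the throughput, assuming the 24,000 tokens/s figure is per GPU:

```python
tokens = 300e9                   # 300B tokens
tokens_per_gpu_second = 24_000   # reported throughput per A100-40G
gpu_hours = tokens / tokens_per_gpu_second / 3600
print(round(gpu_hours))          # 3472, in the same ballpark as the reported 3,456
```
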
  • Hyperparameters:
    • “Consistent with Llama 2’s settings, we utilize the AdamW optimizer (Loshchilov and Hutter, 2019), setting β1 at 0.9 and β2 at 0.95. Additionally, we use a cosine learning rate schedule with a maximum learning rate of 4.0 × 10⁻⁴ and a minimum learning rate of 4.0 × 10⁻⁵. We use 2,000 warmup steps to facilitate optimized learning. We set the batch size as 2M tokens. We assign weight decay as 0.1 and use a gradient clipping threshold of 1.0 to regulate the gradient value. We pretrain TinyLlama with 16 A100-40G GPUs in our project.”
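
A small sketch of the schedule quoted above. The linear warmup shape and the step count derived from a 2M-token batch are my assumptions; the paper only gives the max/min learning rates and 2,000 warmup steps.

```python
import math

def lr_at_step(step, total_steps, max_lr=4.0e-4, min_lr=4.0e-5, warmup_steps=2000):
    """Warmup followed by cosine decay from max_lr down to min_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps           # linear warmup (assumed shape)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))     # decays 1 -> 0 over training
    return min_lr + (max_lr - min_lr) * cosine

# With a 2M-token batch and a ~3T-token budget, total_steps ≈ 1.5M;
# e.g. lr_at_step(750_000, 1_500_000) ≈ 2.2e-4, halfway between max and min.
```
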
  • Version 1.1:
    • Fixed issues with LR schedule and data loading - see https://github.com/jzhang38/TinyLlama/issues/27 and https://github.com/jzhang38/TinyLlama/issues/67
    • Reduced total training tokens to 2T from 3T - downstream performance still improved
    • “three-stage training”: pre-training, continued training on targeted domains, cooldown
      • Pretraining: 1.5T tokens of SlimPajama
      • Three separate continued-pretraining (CPT) variants, drawing on domain data for code and mathematical reasoning and for Chinese: (1) the StarCoder Python and Jupyter subsets, (2) the Proof Pile 2 dataset, (3) the Skypile Chinese dataset
        • All with 350B tokens
        • TinyLlama v1.1: A foundational model for general applications.
        • TinyLlama v1.1 - Math&Code: Enhanced specifically for mathematical and coding tasks. (1) and (2) above
        • TinyLlama v1.1 - Chinese: Specialized for processing and understanding Chinese text. (3) above
          • How did they have the tokenizer handle the Skypile Chinese dataset? Are the Chinese characters already in the tokenizer? Anything else would be a mess (the Llama SentencePiece tokenizer does have byte fallback, so any UTF-8 text including Chinese can always be tokenized, just inefficiently at several tokens per character)
        • When switching data distribution between stages, they do this smoothly by increasing the sampling proportion of the domain-specific data over the first 6B tokens (not sure if that means the first 6B training tokens of the stage, or 6B tokens of the new domain - I think the former; see the sketch after this list)
      • Cooldown: done by increasing the batch size at the end of training from 1.8M to 7.2M tokens; because of the cosine learning rate schedule the LR is already low by then, and cooling down the usual way (decreasing the LR further) might just stop learning - this is what the Chinchilla’s Death post they reference discusses
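
As referenced in the data-switching bullet above, a minimal sketch of what a smooth distribution switch could look like. The linear ramp shape, the 0.3 target proportion, and the iterator interface are all illustrative assumptions; the notes above only say the domain data's sampling proportion is increased over the first 6B tokens of the stage.

```python
import random

def domain_mix_prob(tokens_into_stage, ramp_tokens=6e9, target=0.3):
    """Probability of drawing the next sequence from the new domain-specific data.

    Linear ramp and 0.3 target are assumptions for illustration only."""
    return target * min(1.0, tokens_into_stage / ramp_tokens)

def next_source(general, domain, tokens_into_stage):
    # `general` and `domain` are hypothetical iterators over the two corpora.
    p = domain_mix_prob(tokens_into_stage)
    return domain if random.random() < p else general
```
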
  • Evaluation:
    • Beats OPT-1.3B, Pythia-1.0B, and Pythia-1.4B on HellaSwag, OBQA, WinoGrande, ARC-c, ARC-e, and PIQA
    • Pythia-1.4B is best on BoolQ
    • Wins on MMLU (Measuring Massive Multitask Language Understanding) and HumanEval - a couple of others I didn’t know: BBH (BIG-Bench Hard) and DROP (Discrete Reasoning Over Paragraphs)
    • Eval on Chinese: v1.0 and v1.1 Math&Code beat v1.1 (the general model) - surprising
      • Discovered that 3.5% of the Python and 4% of the Jupyter data consist of Chinese text
  • SlimPajama is on Hugging Face datasets as cerebras/SlimPajama-627B!
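
Since the dataset is right there on the Hub, a small sketch of streaming it and tokenizing with the Llama tokenizer TinyLlama uses. The TinyLlama/TinyLlama-1.1B-Chat-v1.0 checkpoint name and the "text" field name are my assumptions about the current Hub artifacts and may change.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Stream SlimPajama rather than downloading all ~627B tokens of it.
ds = load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True)

# TinyLlama reuses the Llama 2 tokenizer; loading it via a public TinyLlama
# checkpoint avoids the gated meta-llama repositories.
tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

for example in ds.take(3):
    ids = tok(example["text"])["input_ids"]
    print(f"{len(example['text'])} chars -> {len(ids)} tokens")
```
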