Transformers have become the cornerstone of advancements in natural language processing (NLP), powering breakthroughs in tasks ranging from machine translation to large-scale language understanding. Yet, their ability to handle long-context reasoning has remained a challenge. LM2: Large Memory Model is a novel architecture designed to address these limitations while maintaining the performance of existing Transformers.
LM2 builds on the foundation of the decoder-only Transformer by introducing a memory module. This auxiliary module acts as a repository for long-term contextual representations, interacting with input embeddings via cross-attention. Through dynamic gating—comprising input, output, and forget gates—the memory module enables selective updates and retrievals, ensuring critical information is neither lost nor overwritten.
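To make the design concrete, here is a minimal PyTorch sketch of such a gated memory module. The class name, slot count, projection layout, and the way the write candidate is formed are our illustrative choices, not the paper's released implementation.

```python
import torch
import torch.nn as nn

class MemoryModule(nn.Module):
    """Illustrative gated memory module in the spirit of LM2 (not the official code).

    A bank of `num_slots` memory vectors is read via cross-attention
    (input embeddings as queries, memory slots as keys/values) and written
    back through input/forget gates; an output gate controls how much of the
    retrieved memory is exposed to the decoder.
    """

    def __init__(self, d_model: int, num_slots: int = 64):
        super().__init__()
        # Learned initial memory bank; each sequence starts from this state.
        self.memory = nn.Parameter(torch.randn(num_slots, d_model) * 0.02)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.in_gate = nn.Linear(2 * d_model, d_model)      # how much new content to write
        self.forget_gate = nn.Linear(2 * d_model, d_model)  # how much old content to keep
        self.out_gate = nn.Linear(d_model, d_model)         # how much retrieved memory to expose
        self.last_read_attn = None                          # stored for visualization only

    def init_memory(self, batch_size: int) -> torch.Tensor:
        # Every sequence in the batch starts from the same learned memory bank.
        return self.memory.unsqueeze(0).expand(batch_size, -1, -1)

    def forward(self, x: torch.Tensor, memory: torch.Tensor):
        # x: (batch, seq_len, d_model); memory: (batch, num_slots, d_model)
        scale = x.size(-1) ** 0.5

        # Read path: input embeddings are queries, memory slots are keys/values.
        q, k, v = self.q_proj(x), self.k_proj(memory), self.v_proj(memory)
        read_attn = torch.softmax(q @ k.transpose(-1, -2) / scale, dim=-1)  # (B, S, N)
        read = read_attn @ v                                                # (B, S, D)
        self.last_read_attn = read_attn.detach()  # kept only for later inspection

        # Write path: summarize the current tokens per memory slot, then gate the update.
        write_attn = torch.softmax(k @ q.transpose(-1, -2) / scale, dim=-1)  # (B, N, S)
        candidate = write_attn @ x                                           # (B, N, D)
        gate_input = torch.cat([memory, candidate], dim=-1)
        g_in = torch.sigmoid(self.in_gate(gate_input))
        g_forget = torch.sigmoid(self.forget_gate(gate_input))
        new_memory = g_forget * memory + g_in * torch.tanh(candidate)

        # Output gate controls how much retrieved memory flows back into the decoder.
        out = torch.sigmoid(self.out_gate(x)) * read
        return out, new_memory
```

In a full LM2-style decoder block, the gated output would be added to the hidden states alongside the usual self-attention and feed-forward sublayers, and the updated memory bank would be carried forward as decoding proceeds.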
This architecture is visualized in the LM2 diagram (Figure 1), where the primary Transformer layers coexist with an additional memory pathway. This design ensures that the memory augmentation complements the core Transformer capabilities rather than interfering with them.
Memory information flow enables LM2 to retrieve and integrate relevant information dynamically. Input embeddings act as queries against the memory bank, which supplies both the keys and the values. Through cross-attention, the model identifies the memory slots holding the most relevant content and dynamically incorporates it into the Transformer's attention layers, seamlessly blending immediate context with long-term memory.
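In symbols, this read step is ordinary scaled dot-product cross-attention; a generic formulation (notation ours, not necessarily the paper's) is

$$
E_{\text{mem}} = \operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)V,
\qquad Q = E\,W_Q,\quad K = M\,W_K,\quad V = M\,W_V,
$$

where $E$ denotes the input embeddings, $M$ the memory bank, and $E_{\text{mem}}$ the retrieved memory content that is blended into the decoder's attention output.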
Memory updates are designed to maintain the relevance and efficiency of the memory module. The input gate determines how much new information should be written into memory, while the forget gate selectively discards outdated content.
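One standard way to realize such a gated update, in the spirit of the description above (the paper gives the exact parameterization), is

$$
M_{t+1} = g_{\text{forget}} \odot M_t + g_{\text{in}} \odot \tanh\!\big(\tilde{M}_t\big),
\qquad
g_{\text{in}} = \sigma\!\big(W_{\text{in}}[M_t;\tilde{M}_t]\big),\quad
g_{\text{forget}} = \sigma\!\big(W_{\text{forget}}[M_t;\tilde{M}_t]\big),
$$

where $\tilde{M}_t$ is the candidate content derived from the current input via cross-attention, $\sigma$ is the logistic sigmoid, and $\odot$ denotes element-wise multiplication.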
LM2’s strengths are clearest on BABILong, a benchmark specifically crafted to stress-test memory-intensive reasoning. Across its tasks (Table 1), from multi-hop inference to relational reasoning, LM2 consistently outperforms its competitors.
LM2 maintains this advantage even at context lengths of up to 128K tokens (the maximum length supported by Llama-3.2), showcasing its ability to retrieve and synthesize information across vast inputs.
LM2 excels across various reasoning tasks in the BABILong benchmark, including single-step reasoning, multi-step reasoning, relation tracking, and basic queries. Its performance is particularly notable in complex tasks like multi-hop inference, where it surpasses competing models by a significant margin.
To test whether introducing an extra memory module affects an LLM’s general capabilities, we also evaluate LM2 on the MMLU benchmark, which covers a broad spectrum of subject areas (STEM, Humanities, Social Sciences, and Others) and difficulty levels (High School, College, Professional, and General Knowledge). Table 2 presents the results for LM2 in comparison with vanilla-Llama and RMT.
We evaluate the effectiveness of the proposed memory module using perplexity as the primary metric across varying numbers of training tokens (measured in billions).
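For reference, perplexity is the exponential of the average per-token negative log-likelihood; a minimal computation (illustrative, not the paper's evaluation code) looks like this:

```python
import math
import torch
import torch.nn.functional as F

def perplexity(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """Compute perplexity from model logits.

    logits:  (batch, seq_len, vocab_size) raw next-token scores
    targets: (batch, seq_len) ground-truth token ids
    """
    nll = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # flatten batch and sequence dims
        targets.reshape(-1),
        reduction="mean",                     # mean negative log-likelihood per token
    )
    return math.exp(nll.item())               # perplexity = exp(mean NLL)
```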
The figure below illustrates the perplexity trends for the baseline vanilla-Llama and for LM2 with varying degrees of memory integration (i.e., memory modules in 1, 6, 12, or 16 decoder blocks).
The results demonstrate that integrating memory information more extensively throughout the decoder leads to improved model performance.
Specifically, implementing the memory module in only the first block achieves results similar to vanilla-Llama, but with slower convergence. This suggests that introducing a single memory flow does not degrade overall performance, though it may slow training because of the overhead of optimizing the additional memory parameters.
We further investigate how memory updates influence model generation during test time.
Figure 6a shows the cross-attention heatmap prior to memory updates. Here, tokens such as “France” and “Paris” engage strongly with the memory, even though they do not pertain to the target question about photosynthesis. On the first pass, the memory instead focuses on the structure of the question and on identifying factual information.
Next, we examine the memory heatmap after several inference update steps (one inference step corresponds to a single forward pass for one token). As depicted in Figure 6b, the tokens attended to by the memory slots shift toward those relevant to the target question. Since cross-attention exclusively computes the relationships between input tokens and memory, this shift reflects the influence of test-time memory updates and highlights the adaptive nature of memory during inference.
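To reproduce this kind of analysis with the illustrative module sketched earlier, one could plot its stored cross-attention weights; note that `last_read_attn` is our own convenience field, not part of LM2's released code.

```python
import matplotlib.pyplot as plt

def plot_memory_heatmap(memory_module, tokens):
    """Show which input tokens each memory slot attends to after a forward pass."""
    # Weights for the first example, transposed to (num_slots, seq_len).
    attn = memory_module.last_read_attn[0].T.cpu().numpy()
    plt.figure(figsize=(max(6, len(tokens) // 2), 4))
    plt.imshow(attn, aspect="auto", cmap="viridis")
    plt.xticks(range(len(tokens)), tokens, rotation=90)
    plt.xlabel("input token")
    plt.ylabel("memory slot")
    plt.colorbar(label="cross-attention weight")
    plt.tight_layout()
    plt.show()
```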
For further details, access the full research paper and explore the codebase on GitHub. Join us as we redefine the boundaries of what Transformers can achieve!