Map Reduction

code
analysis
llm
Author

Nora Kristiansen, Torjørn Vatnelid

Published

February 15, 2024

Map Reduction is a good way to work around token limitations. If you are summarizing text that is too long for your chosen model’s context window, Map Reduction lets you summarize it piece by piece instead of all at once.

Map Reduction works like this:

- The text is broken up into manageable pieces (chunks).
- A summary is generated for each chunk (the *map* step).
- A final summary is generated from all the chunk summaries (the *reduce* step).

This method is well suited for summarizing some types of text, but it has some limitations:

- The model may over- or underemphasize certain aspects of the text, since each chunk is summarized without seeing the rest.
- It gets expensive quickly, as it requires many calls to the LLM and lots of input tokens.

from dotenv import load_dotenv
from utils import read_files, split_document_by_tokens
from pathlib import Path

import os

# Populate os.environ from a .env file, then read the OpenAI key from the
# environment. os.environ.get yields None when the variable is not set.
load_dotenv()
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')

# Source texts, loaded with the project helper `read_files`. The path is
# relative, so it resolves against the current working directory.
documents = read_files(Path('../../content') / 'books')
# Map-step prompt: applied to every chunk on its own. `{context}` is its only
# input variable, so the map chain's document_variable_name must be "context".
summary_map_template = (
    "Write a short summary of the following text:\n"
    "\n"
    "{context}\n"
    "\n"
    "SUMMARY:\n"
)

# Reduce-step prompt: the per-chunk summaries are inserted into
# `{doc_summaries}` and merged into one final summary.
summary_reduce_template = (
    "The following text is a set of summaries:\n"
    "\n"
    "{doc_summaries}\n"
    "\n"
    "Create a cohesive summary from the above text.\n"
    "SUMMARY:"
)
from langchain.chains import LLMChain, ReduceDocumentsChain, MapReduceDocumentsChain, StuffDocumentsChain
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

def summarize_document(document: list[Document]) -> str:
    """Summarize *document* with a LangChain map-reduce chain.

    The input is split into token-bounded chunks. Each chunk is summarized
    independently with ``summary_map_template`` (map step). The chunk
    summaries are then stuffed into ``summary_reduce_template`` to produce
    one cohesive summary (reduce step).

    Args:
        document: The documents to summarize; passed unchanged to
            ``split_document_by_tokens``.

    Returns:
        The final summary text produced by ``map_reduce_chain.run``.
    """
    llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, model="gpt-4-0125-preview")

    # Map step: one LLM call per chunk.
    map_prompt = PromptTemplate.from_template(summary_map_template)
    map_chain = LLMChain(prompt=map_prompt, llm=llm)

    # Reduce step: stuff all chunk summaries into the reduce prompt's
    # {doc_summaries} slot and ask for one cohesive summary.
    reduce_prompt = PromptTemplate.from_template(summary_reduce_template)
    reduce_chain = LLMChain(prompt=reduce_prompt, llm=llm)
    stuff_chain = StuffDocumentsChain(llm_chain=reduce_chain, document_variable_name="doc_summaries")
    # NOTE(review): token_max is left at LangChain's default; if the combined
    # chunk summaries exceed it, extra collapse passes (more LLM calls) run.
    reduce_docs_chain = ReduceDocumentsChain(combine_documents_chain=stuff_chain)

    # The complete map-reduce chain. document_variable_name must be one of the
    # map prompt's input variables; summary_map_template uses {context}, so any
    # other name makes MapReduceDocumentsChain raise ValueError on construction.
    map_reduce_chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        document_variable_name="context",
        reduce_documents_chain=reduce_docs_chain,
    )

    # NOTE(review): 15000 / 200 presumably mean max tokens per chunk and
    # overlap between chunks — confirm against utils.split_document_by_tokens.
    splitdocs = split_document_by_tokens(document, 15000, 200)
    summary = map_reduce_chain.run(splitdocs)
    return summary