Source code for datable_ai.output
import os
import tiktoken
from langchain.chains.summarize import load_summarize_chain
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import CharacterTextSplitter
from datable_ai.core.llm import (
LLM_TYPE,
create_langfuse_handler,
create_llm,
)
class Output:
"""
A class for generating output using a language model.
Args:
llm_type (LLM_TYPE): The type of language model to use.
prompt_template (str): The prompt template to use for generating output.
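
    Example (illustrative sketch; assumes the relevant ``*_API_MODEL``
    environment variable and provider credentials are configured)::

        output = Output(
            llm_type=LLM_TYPE.OPENAI,
            prompt_template="Summarize the following text: {text}",
        )
        result = output.invoke(text="...")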
"""
def __init__(
self,
llm_type: LLM_TYPE,
prompt_template: str,
) -> None:
self.llm_type = llm_type
self.prompt_template = prompt_template
self.prompt = ChatPromptTemplate.from_template(self.prompt_template)
self.llm = create_llm(self.llm_type)
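        # Resolve the model name and its context-window budget up front so
        # invoke() can decide when an input needs summarizing first.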
self.encoding_name = self._encoding_name()
self.max_tokens = self._max_tokens(self.encoding_name)
def invoke(self, **kwargs):
"""
Generates output using the language model.
Args:
**kwargs: Keyword arguments to pass to the prompt template.
Returns:
The generated output.
Raises:
RuntimeError: If an error occurs while generating output.
"""
        try:
            # Summarize any input that would overflow the model's context
            # window; shorter inputs pass through unchanged.
            summarized_kwargs = {}
            for key, value in kwargs.items():
                num_tokens = self._num_tokens_from_string(value)
                if num_tokens > self.max_tokens:
                    summarized_value = self._summarize(value)
                    summarized_kwargs[key] = summarized_value["output_text"]
                else:
                    summarized_kwargs[key] = value

            # Prompt -> LLM -> string, traced via the Langfuse callback.
            chain = self.prompt | self.llm | StrOutputParser()
return chain.invoke(
summarized_kwargs,
config={"callbacks": [create_langfuse_handler()]},
)
except Exception as e:
raise RuntimeError(f"Error invoking Output: {str(e)}") from e
def _num_tokens_from_string(self, text: str) -> int:
"""
Calculates the number of tokens in a string.
Args:
text (str): The string to calculate the number of tokens for.
Returns:
The number of tokens in the string.
Raises:
RuntimeError: If an error occurs while calculating the number of tokens.
"""
        try:
            if (
                self.llm_type == LLM_TYPE.OPENAI
                or self.llm_type == LLM_TYPE.AZURE_OPENAI
            ):
                # encoding_name holds the model name, so look up the matching
                # tiktoken encoding for that model.
                encoding = tiktoken.encoding_for_model(self.encoding_name)
            elif self.llm_type == LLM_TYPE.ANTHROPIC:
                # Anthropic publishes no tiktoken encoding; cl100k_base is
                # used as an approximation of the true token count.
                encoding = tiktoken.get_encoding("cl100k_base")
            else:
                # Other providers (e.g. Google) fall back to the GPT-2
                # encoding, again as an approximation.
                encoding = tiktoken.get_encoding("gpt2")
            return len(encoding.encode(text))
except Exception as e:
raise RuntimeError(f"Error calculating number of tokens: {str(e)}") from e
    def _summarize(self, long_text: str):
        """
        Summarizes a long text into a shorter one with a map-reduce chain.

        Args:
            long_text (str): The long text to summarize.

        Returns:
            A dictionary with the summary under the ``output_text`` key.

        Raises:
            RuntimeError: If an error occurs while summarizing the text.
        """
        try:
            # Split on tiktoken boundaries so each chunk fits comfortably in
            # the summarization prompt.
            text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=1000, chunk_overlap=50
            )
            split_docs = text_splitter.split_text(long_text)
            docs = [Document(page_content=chunk) for chunk in split_docs]
            # map_reduce summarizes each chunk, then combines the partial
            # summaries into a single result.
            return load_summarize_chain(
                self.llm, chain_type="map_reduce", verbose=False
            ).invoke(docs)
except Exception as e:
raise RuntimeError(f"Error summarizing text: {str(e)}") from e
    def _encoding_name(self):
        """
        Returns the model name for the configured provider, read from the
        corresponding environment variable (also used as the tiktoken
        encoding key for OpenAI models).

        Returns:
            The model name for the language model.

        Raises:
            ValueError: If the language model type is unsupported.
        """
if self.llm_type == LLM_TYPE.OPENAI:
return os.environ.get("OPENAI_API_MODEL")
elif self.llm_type == LLM_TYPE.AZURE_OPENAI:
return os.environ.get("AZURE_OPENAI_API_MODEL")
elif self.llm_type == LLM_TYPE.ANTHROPIC:
return os.environ.get("ANTHROPIC_API_MODEL")
elif self.llm_type == LLM_TYPE.GOOGLE:
return os.environ.get("GOOGLE_API_MODEL")
else:
raise ValueError(f"Unsupported LLM type: {self.llm_type}")
    def _max_tokens(self, model_name: str):
        """
        Returns the context-window token budget for the specified model.

        Args:
            model_name (str): The name of the model.

        Returns:
            The token budget for the model, or 8000 if the model is unknown.
        """
        # Context-window budgets per provider and model; unknown models fall
        # back to a conservative 8000-token default.
        model_configs = {
            "openai": {
                "gpt-4": {"max_tokens": 8000},
                "gpt-4o": {"max_tokens": 128000},
                "gpt-4-32k": {"max_tokens": 32000},
                "gpt-4-turbo": {"max_tokens": 128000},
                "gpt-4-turbo-2024-04-09": {"max_tokens": 128000},
                "gpt-4-turbo-preview": {"max_tokens": 128000},
                "gpt-4-0125-preview": {"max_tokens": 128000},
                "gpt-3.5-turbo": {"max_tokens": 16385},
            },
"anthropic": {
"claude-3-5-sonnet-20240620": {"max_tokens": 200000},
"claude-3-opus-20240229": {"max_tokens": 200000},
"claude-3-sonnet-20240229": {"max_tokens": 200000},
"claude-3-haiku-20240307": {"max_tokens": 200000},
},
"google": {
"gemini-1.5-pro": {"max_tokens": 128000},
},
}
        for models in model_configs.values():
            if model_name in models:
                return models[model_name].get("max_tokens", 8000)
        return 8000
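
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module). Hypothetical values;
# assumes OPENAI_API_MODEL and provider credentials are configured:
#
#     output = Output(
#         llm_type=LLM_TYPE.OPENAI,
#         prompt_template="Answer based on the context:\n{context}\n\nQ: {q}",
#     )
#     # A context longer than the model's token budget is map-reduce
#     # summarized before the prompt is filled; short inputs pass through.
#     answer = output.invoke(context=long_document, q="What changed?")
# ---------------------------------------------------------------------------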