Source code for alter_ego.agents.GPTThread

from typing import Any
import openai
import os
import sys
import time

from alter_ego.agents import APIThread
import alter_ego.utils

client = openai.OpenAI(api_key="")


[docs] class GPTThread(APIThread): """ Class representing a GPT-3 or GPT-4 Thread. """ COST_PER_TOKEN = { "gpt-3.5-turbo": {"prompt": 0.0015 / 1000, "completion": 0.002 / 1000}, "gpt-4": {"prompt": 0.03 / 1000, "completion": 0.06 / 1000}, } def __init__(self, model, temperature, *args, **kwargs) -> None: if "extra_for_module" in kwargs: for k, v in kwargs["extra_for_module"]: setattr(openai, k, v) if "extra_for_client" in kwargs: for k, v in kwargs["extra_for_client"]: setattr(client, k, v) super().__init__(*args, model=model, temperature=temperature, **kwargs)
[docs] def get_api_key(self) -> str: """ Retrieve the OpenAI API key. :return: The OpenAI API key. :rtype: str :raises ValueError: If API key is not found. """ if "OPENAI_KEY" in os.environ: return os.environ["OPENAI_KEY"] elif os.path.exists("openai_key"): return alter_ego.utils.from_file("openai_key") elif os.path.exists("api_key"): return alter_ego.utils.from_file("api_key") else: raise ValueError( "If not specified within the GPTThread constructor (argument api_key), OpenAI API key must be specified in the environment variable OPENAI_KEY, or any of the files openai_key or api_key." )
[docs] def cost(self) -> float: """ Calculate the cost based on the model used and token usage. :return: The cost for the model. :rtype: float :raises NotImplementedError: If the model is not supported. """ cost_per_token = self.COST_PER_TOKEN.get(self.model) if not cost_per_token: raise NotImplementedError return sum( cost_per_token["prompt"] * entry["usage"]["prompt_tokens"] + cost_per_token["completion"] * entry["usage"]["completion_tokens"] for entry in self.log if not isinstance(entry, Exception) )
[docs] def send( self, role: str, message: str, max_tokens: int = 500, **kwargs: Any ) -> str: """ Submit the user message, get the response from the model, and memorize it. :param role: Role of the sender ("user"). :type role: str :param message: The user's message to submit. :type message: str :param max_tokens: Maximum number of tokens for the model to generate. :type max_tokens: int :keyword kwargs: Additional keyword arguments. :type kwargs: Any :return: The model's response. :rtype: str """ if role == "user": time.sleep(self.delay) llm_out = self.get_model_output(message, max_tokens) response = llm_out.choices[0].message.content self.memorize("assistant", response) return response
[docs] def get_model_output(self, message: str, max_tokens: int) -> str: """ Get the model output for the given message. :param message: The user's message. :type message: str :param max_tokens: Maximum number of tokens for the model to generate. :type max_tokens: int :return: The model output. :rtype: str :raises RuntimeError: If maximum number of retries is exceeded. """ client.api_key = self.api_key retries = 0 while retries <= self.max_retries: try: if self.verbose: print("+", end="", file=sys.stderr, flush=True) llm_out = client.chat.completions.create( model=self.model, messages=self.history, max_tokens=max_tokens, n=1, stop=None, temperature=self.temperature, ) self.log.append(llm_out) return llm_out except Exception as e: retries += 1 self.log.append(e) time.sleep(1) raise RuntimeError(f"max_retries ({self.max_retries}) exceeded for {self}.")