Source code for alter_ego.structure

from abc import ABC, abstractmethod
from typing import Callable, Dict, Iterator, List, Set, Tuple, Union, Any
from jinja2 import Environment, StrictUndefined
from alter_ego.structure.Relay import Relay
import copy
import uuid
import os
import pickle
import json

VALID_ROLES = ["system", "user", "assistant"]


class Thread(ABC):
    """
    Abstract base class representing a Thread.

    A Thread keeps an ordered message history (a list of dicts with the keys
    ``"role"`` and ``"content"``), renders outgoing messages as Jinja
    templates and hands the rendered text to :meth:`send`, which subclasses
    must implement.
    """

    def __init__(self, **params: Any) -> None:
        """
        Initialize a Thread instance.

        :param params: Arbitrary keyword parameters. They become instance
            attributes and are therefore also available as template variables
            in :meth:`prepare`. Parameters named like the attributes set here
            (``id``, ``metadata``, ``_history``, ``tainted``, ``convo``,
            ``history_hooks``, ``choices``, ``env``) are overwritten by the
            defaults below.
        """
        # Caller-supplied parameters first; the defaults below then take
        # precedence over any parameter of the same name.
        self.__dict__ |= params
        self.id: uuid.UUID = uuid.uuid4()
        self.metadata: Dict[str, Any] = {}
        self._history: List[Dict[str, str]] = []
        self.tainted: bool = False
        self.convo = None  # Assigned by Conversation when this Thread joins one
        self.history_hooks: Set[Callable] = set()
        self.choices: List[Any] = []
        # StrictUndefined: rendering fails on template variables that are not defined
        self.env: Environment = Environment(undefined=StrictUndefined)

    def __repr__(self) -> str:
        """
        :return: String representation of the Thread, using its ``name`` if it
            has one, otherwise the first 8 characters of its id.
        """
        return f"<{self.__class__.__name__}/{self.name if 'name' in self.__dict__ else str(self.id)[0:8]}>"

    @property
    def history(self) -> List[Dict[str, str]]:
        """
        :return: Deep copy of message history, so callers cannot mutate it.
        """
        return copy.deepcopy(self._history)

    def memorize(self, role: str, message: str) -> None:
        """
        Memorizes a message and its associated role.

        After the message is appended, every callable in ``history_hooks`` is
        called with this Thread as its only argument.

        :param role: Role of the message ('system', 'user', or 'assistant').
        :param message: Message content.
        :raises ValueError: If the role is invalid or if a system message already exists.
        """
        if role not in VALID_ROLES:
            raise ValueError(f'Invalid role "{role}".')
        # Scan the private list directly; ``self.history`` would deep-copy the
        # entire history just to look at the roles.
        if role == "system" and any(
            item["role"] == "system" for item in self._history
        ):
            raise ValueError("System message already set.")

        self._history.append({"role": role, "content": message})

        # Iterate over a snapshot so a hook may add or remove hooks (including
        # itself) without raising "Set changed size during iteration".
        for hook in tuple(self.history_hooks):
            hook(self)

    def prepare(self, template: str, **extra: Any) -> str:
        """
        Prepare a template with additional parameters.

        The template is rendered with all instance attributes as variables,
        plus ``extra``. If a name occurs in both, the value from ``extra`` is
        used.

        :param template: Template string.
        :param extra: Additional parameters to inject into the template.
        :return: Rendered template string.
        """
        compiled = self.env.from_string(template)
        # Build one mapping instead of splatting both dicts into render():
        # a shared key would otherwise raise "got multiple values for
        # keyword argument".
        context = {**self.__dict__, **extra}
        return compiled.render(context)

    def save(
        self, subdir: str = ".", outdir: str = "out", full_save: bool = True
    ) -> None:
        """
        Saves the current Thread into a file named after its id.

        The file goes to ``{outdir}/{subdir}`` or, if the Thread belongs to a
        Conversation, to ``{outdir}/{subdir}/{conversation id}``. Missing
        directories are created.

        :param subdir: The sub-directory to save the file in.
        :param outdir: The main directory to save the file in.
        :param full_save: If True, pickle the whole Thread to ``{id}.pkl``;
            otherwise write only its history and metadata as JSON to ``{id}.json``.
        """
        target_dir = f"{outdir}/{subdir}"
        if self.convo is not None:
            target_dir = f"{target_dir}/{self.convo.id}"
        os.makedirs(target_dir, exist_ok=True)

        if full_save:
            with open(f"{target_dir}/{self.id}.pkl", "wb") as fp:
                pickle.dump(self, fp)
        else:
            with open(f"{target_dir}/{self.id}.json", "w", encoding="utf-8") as fp:
                json.dump(dict(history=self.history, metadata=self.metadata), fp)

    def cost(self) -> float:
        """
        Computes and returns the cost associated with the Thread.

        :returns: The cost, 0.0 for this base implementation. Adjust in subclasses.
        """
        return 0.0

    def _render_and_send(self, role: str, message: str, **kwargs: Any) -> Any:
        """Render ``message``, memorize it under ``role``, then pass it to :meth:`send`."""
        rendered = self.prepare(message, **kwargs)
        self.memorize(role, rendered)
        return self.send(role, rendered, **kwargs)

    def system(self, message: str, **kwargs: Any) -> Any:
        """
        Sends a system-level message.

        :param message: The message to send.
        :param kwargs: Additional keyword arguments for message preparation;
            also forwarded to :meth:`send`.
        :returns: Return value from the send method.
        :raises ValueError: If a system message was already memorized.
        """
        return self._render_and_send("system", message, **kwargs)

    def user(self, message: str, **kwargs: Any) -> Any:
        """
        Sends a user-level message.

        :param message: The message to send.
        :param kwargs: Additional keyword arguments for message preparation;
            also forwarded to :meth:`send`.
        :returns: Return value from the send method.
        """
        return self._render_and_send("user", message, **kwargs)

    def assistant(self, message: str, **kwargs: Any) -> Any:
        """
        Sends an assistant-level message.

        :param message: The message to send.
        :param kwargs: Additional keyword arguments for message preparation;
            also forwarded to :meth:`send`.
        :returns: Return value from the send method.
        """
        return self._render_and_send("assistant", message, **kwargs)

    def submit(self, message: str, **kwargs: Any) -> Any:
        """
        Submits a message as a user and sends it after preparation.

        Equivalent to :meth:`user`.

        :param message: The message to submit.
        :param kwargs: Additional keyword arguments for message preparation;
            also forwarded to :meth:`send`.
        :returns: Return value from the send method.
        """
        return self._render_and_send("user", message, **kwargs)

    @abstractmethod
    def send(self, role: str, message: str, **kwargs: Any) -> Any:
        """
        Abstract method that must be implemented by subclasses to send messages.

        :param role: Role of the sender, can be 'system', 'user', or 'assistant'.
        :param message: The message to be sent.
        :param kwargs: Additional keyword arguments.
        :returns: Implementation dependent.
        """
        pass
class Conversation:
    """
    A Conversation bundles several Threads and links them to one another.

    Threads are given either positionally or by keyword, never both. Keyword
    names also become attributes of the Conversation, so
    ``Conversation(a=t1, b=t2)`` exposes ``conversation.a`` and ``conversation.b``.
    """

    def __init__(self, *threads: Thread, **named_threads: Thread) -> None:
        """
        Build a Conversation from its Threads and register it on each of them.

        Every Thread gets ``convo`` pointing back to this Conversation. With
        exactly two Threads, each one also gets ``other`` (its counterpart);
        with more than two, each gets ``others`` (a list of the remaining Threads).

        :param threads: Thread instances as positional arguments.
        :param named_threads: Thread instances as named arguments.
        :raises ValueError: If both 'threads' and 'named_threads' are used, or none of them are used.
        """
        # Exactly one of the two calling conventions must be used.
        if bool(threads) == bool(named_threads):
            raise ValueError(
                "You may only use 'threads' or 'named_threads', not both, not neither."
            )

        if named_threads:
            self.threads = tuple(named_threads.values())
            # Make each named Thread reachable as an attribute.
            self.__dict__ |= named_threads
        else:
            self.threads = tuple(threads)

        self.now = 1  # 1-based position of the current Thread (see __iter__)
        self.all = Relay(self, self.threads)
        self.id = uuid.uuid4()

        size = len(self.threads)
        for member in self.threads:
            member.convo = self
            if size == 2:
                head, tail = self.threads
                member.other = tail if member == head else head
            elif size > 2:
                member.others = [peer for peer in self.threads if peer != member]

    def __iter__(self) -> Iterator[Thread]:
        """
        Iterate over the Threads in order, keeping ``now`` up to date.

        Before each Thread is yielded, ``now`` is set to its 1-based position.

        :yields: Each Thread in the Conversation.
        """
        position = 0
        for member in self.threads:
            position += 1
            self.now = position
            yield member