# Source code for alter_ego.agents.TextSynthThread
from typing import Any, Dict, Optional
import json
import os
import requests
import sys
import time
from alter_ego.agents import APIThread
class TextSynthThread(APIThread):
    """
    Thread that talks to the TextSynth chat API.
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Set TextSynth-specific defaults, then defer to the base class.

        :keyword kwargs: Forwarded verbatim to :class:`APIThread`;
            presumably matching keywords can override the defaults set
            here, since they are assigned first — confirm against
            ``APIThread.__init__``.
        :type kwargs: Any
        """
        # Defaults are assigned before super().__init__ runs so the base
        # class sees (and may replace) them.
        self.endpoint = "https://api.textsynth.com/v1/engines/falcon_40B-chat/chat"
        self.temperature = 1.0
        super().__init__(**kwargs)
def ts_data(self) -> Dict[str, Any]:
"""
Prepare the data for TextSynth API call.
:return: Data to be sent in the API request.
:rtype: Dict[str, Any]
:raises ValueError: If an invalid history item is encountered.
"""
system_set = False
next_role = 1
data = dict(messages=[])
for item in self._history:
if item["role"] == "system" and not system_set:
data["system"] = item["content"]
system_set = True
elif item["role"] == "user" and next_role == 1:
data["messages"].append(item["content"])
next_role = 2
elif item["role"] == "assistant" and next_role == 2:
data["messages"].append(item["content"])
next_role = 1
else:
raise ValueError(f"The following history item is invalid: {item}")
return data
def send(
self,
role: str,
message: str,
max_tokens: int = 500,
extra_params: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""
Submit the user message, receive the model's response, and memorize it.
:param role: Role of the sender ("user").
:type role: str
:param message: The user's message.
:type message: str
:param max_tokens: Maximum number of tokens for the model to generate.
:type max_tokens: int
:param extra_params: Additional parameters for the model.
:type extra_params: Optional[Dict[str, Any]]
:keyword kwargs: Additional keyword arguments.
:type kwargs: Any
:return: The model's response.
:rtype: str
"""
if role == "user":
time.sleep(self.delay)
llm_out = self.get_model_output(message, max_tokens, extra_params)
response = llm_out["text"]
self.memorize("assistant", response)
return response
def get_model_output(
self,
message: str,
max_tokens: int,
extra_params: Optional[Dict[str, Any]] = None,
) -> Any:
"""
Get the model output for the given message.
:param message: The user's message.
:type message: str
:param max_tokens: Maximum number of tokens for the model to generate.
:type max_tokens: int
:param extra_params: Additional parameters for the model.
:type extra_params: Optional[Dict[str, Any]]
:return: The model output.
:rtype: Any
:raises RuntimeError: If the maximum number of retries is exceeded.
"""
retries = 0
while retries <= self.max_retries:
try:
if self.verbose:
print("+", end="", file=sys.stderr, flush=True)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
params = {
"max_tokens": max_tokens,
"temperature": self.temperature,
} | (extra_params if extra_params is not None else {})
rq = requests.post(
self.endpoint,
headers=headers,
data=json.dumps(self.ts_data() | params),
)
llm_out = rq.json()
self.log.append(llm_out)
return llm_out
except Exception as e:
retries += 1
self.log.append(e)
time.sleep(1)
raise RuntimeError(f"max_retries ({self.max_retries}) exceeded for {self}.")