Merge pull request #25 from maxer137/main
Add support for other backends, such as OpenRouter and Ollama
This commit is contained in:
@@ -1,19 +1,23 @@
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FinancialSituationMemory:
|
||||
def __init__(self, name):
|
||||
self.client = OpenAI()
|
||||
def __init__(self, name, config):
|
||||
if config["openai_backend"] == "http://localhost:11434/v1":
|
||||
self.embedding = "nomic-embed-text"
|
||||
else:
|
||||
self.embedding = "text-embedding-ada-002"
|
||||
self.client = OpenAI(base_url=config["openai_backend"])
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
|
||||
def get_embedding(self, text):
|
||||
"""Get OpenAI embedding for a text"""
|
||||
|
||||
response = self.client.embeddings.create(
|
||||
model="text-embedding-ada-002", input=text
|
||||
model=self.embedding, input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
|
||||
@@ -703,10 +703,11 @@ def get_YFin_data(
|
||||
|
||||
|
||||
def get_stock_news_openai(ticker, curr_date):
|
||||
client = OpenAI()
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["openai_backend"])
|
||||
|
||||
response = client.responses.create(
|
||||
model="gpt-4.1-mini",
|
||||
model=config["quick_think_llm"],
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
@@ -737,10 +738,11 @@ def get_stock_news_openai(ticker, curr_date):
|
||||
|
||||
|
||||
def get_global_news_openai(curr_date):
|
||||
client = OpenAI()
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["openai_backend"])
|
||||
|
||||
response = client.responses.create(
|
||||
model="gpt-4.1-mini",
|
||||
model=config["quick_think_llm"],
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
@@ -771,10 +773,11 @@ def get_global_news_openai(curr_date):
|
||||
|
||||
|
||||
def get_fundamentals_openai(ticker, curr_date):
|
||||
client = OpenAI()
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["openai_backend"])
|
||||
|
||||
response = client.responses.create(
|
||||
model="gpt-4.1-mini",
|
||||
model=config["quick_think_llm"],
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
|
||||
@@ -10,6 +10,7 @@ DEFAULT_CONFIG = {
|
||||
# LLM settings
|
||||
"deep_think_llm": "o4-mini",
|
||||
"quick_think_llm": "gpt-4o-mini",
|
||||
"openai_backend": "https://api.openai.com/v1",
|
||||
# Debate and discussion settings
|
||||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
|
||||
@@ -55,18 +55,18 @@ class TradingAgentsGraph:
|
||||
)
|
||||
|
||||
# Initialize LLMs
|
||||
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"])
|
||||
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["openai_backend"],)
|
||||
self.quick_thinking_llm = ChatOpenAI(
|
||||
model=self.config["quick_think_llm"], temperature=0.1
|
||||
model=self.config["quick_think_llm"], temperature=0.1, base_url=self.config["openai_backend"],
|
||||
)
|
||||
self.toolkit = Toolkit(config=self.config)
|
||||
|
||||
# Initialize memories
|
||||
self.bull_memory = FinancialSituationMemory("bull_memory")
|
||||
self.bear_memory = FinancialSituationMemory("bear_memory")
|
||||
self.trader_memory = FinancialSituationMemory("trader_memory")
|
||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory")
|
||||
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory")
|
||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
||||
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
|
||||
|
||||
# Create tool nodes
|
||||
self.tool_nodes = self._create_tool_nodes()
|
||||
|
||||
Reference in New Issue
Block a user