diff --git a/cli/main.py b/cli/main.py index e7bed4e..dbc8eb6 100644 --- a/cli/main.py +++ b/cli/main.py @@ -444,20 +444,29 @@ def get_user_selections(): ) selected_research_depth = select_research_depth() - # Step 5: Thinking agents + # Step 5: OpenAI backend console.print( create_question_box( - "Step 5: Thinking Agents", "Select your thinking agents for analysis" + "Step 5: OpenAI backend", "Select which service to talk to" ) ) - selected_shallow_thinker = select_shallow_thinking_agent() - selected_deep_thinker = select_deep_thinking_agent() + selected_openai_backend = select_openai_backend() + + # Step 6: Thinking agents + console.print( + create_question_box( + "Step 6: Thinking Agents", "Select your thinking agents for analysis" + ) + ) + selected_shallow_thinker = select_shallow_thinking_agent(selected_openai_backend) + selected_deep_thinker = select_deep_thinking_agent(selected_openai_backend) return { "ticker": selected_ticker, "analysis_date": analysis_date, "analysts": selected_analysts, "research_depth": selected_research_depth, + "openai_backend": selected_openai_backend, "shallow_thinker": selected_shallow_thinker, "deep_thinker": selected_deep_thinker, } @@ -694,6 +703,7 @@ def run_analysis(): config["max_risk_discuss_rounds"] = selections["research_depth"] config["quick_think_llm"] = selections["shallow_thinker"] config["deep_think_llm"] = selections["deep_thinker"] + config["openai_backend"] = selections["openai_backend"] # Initialize the graph graph = TradingAgentsGraph( diff --git a/cli/utils.py b/cli/utils.py index c386525..29c9324 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -122,22 +122,32 @@ def select_research_depth() -> int: return choice -def select_shallow_thinking_agent() -> str: +def select_shallow_thinking_agent(backend) -> str: """Select shallow thinking llm engine using an interactive selection.""" # Define shallow thinking llm engine options with their corresponding model names - SHALLOW_AGENT_OPTIONS = [ - ("GPT-4o-mini - Fast and efficient for quick tasks", "gpt-4o-mini"), - ("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"), - ("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"), - ("GPT-4o - Standard model with solid capabilities", "gpt-4o"), - ] + SHALLOW_AGENT_OPTIONS = { + "https://api.openai.com/v1": [ + ("GPT-4o-mini - Fast and efficient for quick tasks", "gpt-4o-mini"), + ("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"), + ("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"), + ("GPT-4o - Standard model with solid capabilities", "gpt-4o"), + ], + "https://openrouter.ai/api/v1": [ + ("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"), + ("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"), + ("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"), + ], + "http://localhost:11434/v1": [ + ("llama3.2 local", "llama3.2"), + ] + } choice = questionary.select( "Select Your [Quick-Thinking LLM Engine]:", choices=[ questionary.Choice(display, value=value) - for display, value in SHALLOW_AGENT_OPTIONS + for display, value in SHALLOW_AGENT_OPTIONS[backend] ], instruction="\n- Use arrow keys to navigate\n- Press Enter to select", style=questionary.Style( @@ -158,25 +168,34 @@ def select_shallow_thinking_agent() -> str: return choice -def select_deep_thinking_agent() -> str: +def select_deep_thinking_agent(backend) -> str: """Select deep thinking llm engine using an interactive selection.""" # Define deep thinking llm engine options with their corresponding model names - DEEP_AGENT_OPTIONS = [ - ("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"), - ("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"), - ("GPT-4o - Standard model with solid capabilities", "gpt-4o"), - ("o4-mini - Specialized reasoning model (compact)", "o4-mini"), - ("o3-mini - Advanced reasoning model (lightweight)", "o3-mini"), - ("o3 - Full advanced reasoning model", "o3"), - ("o1 - Premier reasoning and problem-solving model", "o1"), - ] - + DEEP_AGENT_OPTIONS = { + "https://api.openai.com/v1": [ + ("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"), + ("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"), + ("GPT-4o - Standard model with solid capabilities", "gpt-4o"), + ("o4-mini - Specialized reasoning model (compact)", "o4-mini"), + ("o3-mini - Advanced reasoning model (lightweight)", "o3-mini"), + ("o3 - Full advanced reasoning model", "o3"), + ("o1 - Premier reasoning and problem-solving model", "o1"), + ], + "https://openrouter.ai/api/v1": [ + ("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"), + ("deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"), + ], + "http://localhost:11434/v1": [ + ("qwen3", "qwen3"), + ] + } + choice = questionary.select( "Select Your [Deep-Thinking LLM Engine]:", choices=[ questionary.Choice(display, value=value) - for display, value in DEEP_AGENT_OPTIONS + for display, value in DEEP_AGENT_OPTIONS[backend] ], instruction="\n- Use arrow keys to navigate\n- Press Enter to select", style=questionary.Style( @@ -193,3 +212,35 @@ def select_deep_thinking_agent() -> str: exit(1) return choice + +def select_openai_backend() -> str: + """Select the OpenAI api url using interactive selection.""" + + # Define OpenAI api options with their corresponding endpoints + OPENAI_BASE_URLS = [ + ("OpenAI - Requires an OpenAPI Key", "https://api.openai.com/v1"), + ("Openrouter - Requires an OpenRouter API Key", "https://openrouter.ai/api/v1"), + ("Ollama - Local", "http://localhost:11434/v1") + ] + + choice = questionary.select( + "Select your [OpenAI endpoint]:", + choices=[ + questionary.Choice(display, value=value) + for display, value in OPENAI_BASE_URLS + ], + instruction="\n- Use arrow keys to navigate\n- Press Enter to select", + style=questionary.Style( + [ + ("selected", "fg:magenta noinherit"), + ("highlighted", "fg:magenta noinherit"), + ("pointer", "fg:magenta noinherit"), + ] + ), + ).ask() + + if choice is None: + console.print("\n[red]no OpenAI backend selected. Exiting...[/red]") + exit(1) + + return choice diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index a1934bd..cdd9e2e 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,19 +1,23 @@ import chromadb from chromadb.config import Settings from openai import OpenAI -import numpy as np class FinancialSituationMemory: - def __init__(self, name): - self.client = OpenAI() + def __init__(self, name, config): + if config["openai_backend"] == "http://localhost:11434/v1": + self.embedding = "nomic-embed-text" + else: + self.embedding = "text-embedding-ada-002" + self.client = OpenAI(base_url=config["openai_backend"]) self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.situation_collection = self.chroma_client.create_collection(name=name) def get_embedding(self, text): """Get OpenAI embedding for a text""" + response = self.client.embeddings.create( - model="text-embedding-ada-002", input=text + model=self.embedding, input=text ) return response.data[0].embedding diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index e0c0b70..f151a74 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -703,10 +703,11 @@ def get_YFin_data( def get_stock_news_openai(ticker, curr_date): - client = OpenAI() + config = get_config() + client = OpenAI(base_url=config["openai_backend"]) response = client.responses.create( - model="gpt-4.1-mini", + model=config["quick_think_llm"], input=[ { "role": "system", @@ -737,10 +738,11 @@ def get_stock_news_openai(ticker, curr_date): def get_global_news_openai(curr_date): - client = OpenAI() + config = get_config() + client = OpenAI(base_url=config["openai_backend"]) response = client.responses.create( - model="gpt-4.1-mini", + model=config["quick_think_llm"], input=[ { "role": "system", @@ -771,10 +773,11 @@ def get_global_news_openai(curr_date): def get_fundamentals_openai(ticker, curr_date): - client = OpenAI() + config = get_config() + client = OpenAI(base_url=config["openai_backend"]) response = client.responses.create( - model="gpt-4.1-mini", + model=config["quick_think_llm"], input=[ { "role": "system", diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 5bb2548..da1d246 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -10,6 +10,7 @@ DEFAULT_CONFIG = { # LLM settings "deep_think_llm": "o4-mini", "quick_think_llm": "gpt-4o-mini", + "openai_backend": "https://api.openai.com/v1", # Debate and discussion settings "max_debate_rounds": 1, "max_risk_discuss_rounds": 1, diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index bbd4507..5d0ce6b 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -55,18 +55,18 @@ class TradingAgentsGraph: ) # Initialize LLMs - self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"]) + self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["openai_backend"],) self.quick_thinking_llm = ChatOpenAI( - model=self.config["quick_think_llm"], temperature=0.1 + model=self.config["quick_think_llm"], temperature=0.1, base_url=self.config["openai_backend"], ) self.toolkit = Toolkit(config=self.config) # Initialize memories - self.bull_memory = FinancialSituationMemory("bull_memory") - self.bear_memory = FinancialSituationMemory("bear_memory") - self.trader_memory = FinancialSituationMemory("trader_memory") - self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory") - self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory") + self.bull_memory = FinancialSituationMemory("bull_memory", self.config) + self.bear_memory = FinancialSituationMemory("bear_memory", self.config) + self.trader_memory = FinancialSituationMemory("trader_memory", self.config) + self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config) + self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config) # Create tool nodes self.tool_nodes = self._create_tool_nodes()