From 43aa9c5d09baac7a4fda6d0fce82e9d21b0f3532 Mon Sep 17 00:00:00 2001 From: Max Wong Date: Thu, 26 Jun 2025 00:27:01 -0400 Subject: [PATCH] Local Ollama (#53) - Fix typo 'Start' 'End' - Add llama3.1 selection - Use 'quick_think_llm' model instead of hard-coding GPT --- cli/utils.py | 2 ++ tradingagents/agents/utils/agent_utils.py | 4 ++-- tradingagents/agents/utils/memory.py | 2 +- tradingagents/dataflows/interface.py | 16 ++++++++-------- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/cli/utils.py b/cli/utils.py index d387336..7b9682a 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -150,6 +150,7 @@ def select_shallow_thinking_agent(provider) -> str: ("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"), ], "ollama": [ + ("llama3.1 local", "llama3.1"), ("llama3.2 local", "llama3.2"), ] } @@ -211,6 +212,7 @@ def select_deep_thinking_agent(provider) -> str: ("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"), ], "ollama": [ + ("llama3.1 local", "llama3.1"), ("qwen3", "qwen3"), ] } diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index b7313b7..0b07f04 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -124,7 +124,7 @@ class Toolkit: def get_YFin_data( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], ) -> str: """ Retrieve the stock price data for a given ticker symbol from Yahoo Finance. @@ -145,7 +145,7 @@ class Toolkit: def get_YFin_data_online( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], ) -> str: """ Retrieve the stock price data for a given ticker symbol from Yahoo Finance. diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index f341576..69b8ab8 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -9,7 +9,7 @@ class FinancialSituationMemory: self.embedding = "nomic-embed-text" else: self.embedding = "text-embedding-3-small" - self.client = OpenAI() + self.client = OpenAI(base_url=config["backend_url"]) self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.situation_collection = self.chroma_client.create_collection(name=name) diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index a095294..7fffbb4 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -628,7 +628,7 @@ def get_YFin_data_window( def get_YFin_data_online( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], ): datetime.strptime(start_date, "%Y-%m-%d") @@ -670,7 +670,7 @@ def get_YFin_data_online( def get_YFin_data( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], ) -> str: # read in data data = pd.read_csv( @@ -704,10 +704,10 @@ def get_YFin_data( def get_stock_news_openai(ticker, curr_date): config = get_config() - client = OpenAI() + client = OpenAI(base_url=config["backend_url"]) response = client.responses.create( - model="gpt-4.1-mini", + model=config["quick_think_llm"], input=[ { "role": "system", @@ -739,10 +739,10 @@ def get_stock_news_openai(ticker, curr_date): def get_global_news_openai(curr_date): config = get_config() - client = OpenAI() + client = OpenAI(base_url=config["backend_url"]) response = client.responses.create( - model="gpt-4.1-mini", + model=config["quick_think_llm"], input=[ { "role": "system", @@ -774,10 +774,10 @@ def get_global_news_openai(curr_date): def get_fundamentals_openai(ticker, curr_date): config = get_config() - client = OpenAI() + client = OpenAI(base_url=config["backend_url"]) response = client.responses.create( - model="gpt-4.1-mini", + model=config["quick_think_llm"], input=[ { "role": "system",