1 Commits

Author SHA1 Message Date
Yijia Xiao
52ceb6d010 Revert "Docker support and Ollama support (#47)"
This reverts commit 78ea029a0b.
2025-06-26 00:00:39 -04:00
7 changed files with 11 additions and 64 deletions

1
.gitignore vendored
View File

@@ -6,4 +6,3 @@ src/
eval_results/ eval_results/
eval_data/ eval_data/
*.egg-info/ *.egg-info/
.env

View File

@@ -1,8 +1,6 @@
from typing import Optional from typing import Optional
import datetime import datetime
import typer import typer
from pathlib import Path
from functools import wraps
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
from rich.spinner import Spinner from rich.spinner import Spinner
@@ -749,53 +747,6 @@ def run_analysis():
[analyst.value for analyst in selections["analysts"]], config=config, debug=True [analyst.value for analyst in selections["analysts"]], config=config, debug=True
) )
# Create result directory
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
results_dir.mkdir(parents=True, exist_ok=True)
report_dir = results_dir / "reports"
report_dir.mkdir(parents=True, exist_ok=True)
log_file = results_dir / "message_tool.log"
log_file.touch(exist_ok=True)
def save_message_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
def wrapper(*args, **kwargs):
func(*args, **kwargs)
timestamp, message_type, content = obj.messages[-1]
content = content.replace("\n", " ") # Replace newlines with spaces
with open(log_file, "a") as f:
f.write(f"{timestamp} [{message_type}] {content}\n")
return wrapper
def save_tool_call_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
def wrapper(*args, **kwargs):
func(*args, **kwargs)
timestamp, tool_name, args = obj.tool_calls[-1]
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
with open(log_file, "a") as f:
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
return wrapper
def save_report_section_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
def wrapper(section_name, content):
func(section_name, content)
if section_name in obj.report_sections and obj.report_sections[section_name] is not None:
content = obj.report_sections[section_name]
if content:
file_name = f"{section_name}.md"
with open(report_dir / file_name, "w") as f:
f.write(content)
return wrapper
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
# Now start the display layout # Now start the display layout
layout = create_layout() layout = create_layout()

View File

@@ -150,7 +150,6 @@ def select_shallow_thinking_agent(provider) -> str:
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"), ("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
], ],
"ollama": [ "ollama": [
("llama3.1 local", "llama3.1"),
("llama3.2 local", "llama3.2"), ("llama3.2 local", "llama3.2"),
] ]
} }
@@ -212,7 +211,6 @@ def select_deep_thinking_agent(provider) -> str:
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"), ("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
], ],
"ollama": [ "ollama": [
("llama3.1 local", "llama3.1"),
("qwen3", "qwen3"), ("qwen3", "qwen3"),
] ]
} }

View File

@@ -124,7 +124,7 @@ class Toolkit:
def get_YFin_data( def get_YFin_data(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"], end_date: Annotated[str, "Start date in yyyy-mm-dd format"],
) -> str: ) -> str:
""" """
Retrieve the stock price data for a given ticker symbol from Yahoo Finance. Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
@@ -145,7 +145,7 @@ class Toolkit:
def get_YFin_data_online( def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"], end_date: Annotated[str, "Start date in yyyy-mm-dd format"],
) -> str: ) -> str:
""" """
Retrieve the stock price data for a given ticker symbol from Yahoo Finance. Retrieve the stock price data for a given ticker symbol from Yahoo Finance.

View File

@@ -9,7 +9,7 @@ class FinancialSituationMemory:
self.embedding = "nomic-embed-text" self.embedding = "nomic-embed-text"
else: else:
self.embedding = "text-embedding-3-small" self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"]) self.client = OpenAI()
self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name) self.situation_collection = self.chroma_client.create_collection(name=name)

View File

@@ -628,7 +628,7 @@ def get_YFin_data_window(
def get_YFin_data_online( def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"], end_date: Annotated[str, "Start date in yyyy-mm-dd format"],
): ):
datetime.strptime(start_date, "%Y-%m-%d") datetime.strptime(start_date, "%Y-%m-%d")
@@ -670,7 +670,7 @@ def get_YFin_data_online(
def get_YFin_data( def get_YFin_data(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"], end_date: Annotated[str, "Start date in yyyy-mm-dd format"],
) -> str: ) -> str:
# read in data # read in data
data = pd.read_csv( data = pd.read_csv(
@@ -704,10 +704,10 @@ def get_YFin_data(
def get_stock_news_openai(ticker, curr_date): def get_stock_news_openai(ticker, curr_date):
config = get_config() config = get_config()
client = OpenAI(base_url=config["backend_url"]) client = OpenAI()
response = client.responses.create( response = client.responses.create(
model=config["quick_think_llm"], model="gpt-4.1-mini",
input=[ input=[
{ {
"role": "system", "role": "system",
@@ -739,10 +739,10 @@ def get_stock_news_openai(ticker, curr_date):
def get_global_news_openai(curr_date): def get_global_news_openai(curr_date):
config = get_config() config = get_config()
client = OpenAI(base_url=config["backend_url"]) client = OpenAI()
response = client.responses.create( response = client.responses.create(
model=config["quick_think_llm"], model="gpt-4.1-mini",
input=[ input=[
{ {
"role": "system", "role": "system",
@@ -774,10 +774,10 @@ def get_global_news_openai(curr_date):
def get_fundamentals_openai(ticker, curr_date): def get_fundamentals_openai(ticker, curr_date):
config = get_config() config = get_config()
client = OpenAI(base_url=config["backend_url"]) client = OpenAI()
response = client.responses.create( response = client.responses.create(
model=config["quick_think_llm"], model="gpt-4.1-mini",
input=[ input=[
{ {
"role": "system", "role": "system",

View File

@@ -2,7 +2,6 @@ import os
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data", "data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
"data_cache_dir": os.path.join( "data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),