Compare commits
1 Commits
main
...
revert-47-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
52ceb6d010 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,4 +6,3 @@ src/
|
|||||||
eval_results/
|
eval_results/
|
||||||
eval_data/
|
eval_data/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
.env
|
|
||||||
|
|||||||
49
cli/main.py
49
cli/main.py
@@ -1,8 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import datetime
|
import datetime
|
||||||
import typer
|
import typer
|
||||||
from pathlib import Path
|
|
||||||
from functools import wraps
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.spinner import Spinner
|
from rich.spinner import Spinner
|
||||||
@@ -749,53 +747,6 @@ def run_analysis():
|
|||||||
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
|
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create result directory
|
|
||||||
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
report_dir = results_dir / "reports"
|
|
||||||
report_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
log_file = results_dir / "message_tool.log"
|
|
||||||
log_file.touch(exist_ok=True)
|
|
||||||
|
|
||||||
def save_message_decorator(obj, func_name):
|
|
||||||
func = getattr(obj, func_name)
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
func(*args, **kwargs)
|
|
||||||
timestamp, message_type, content = obj.messages[-1]
|
|
||||||
content = content.replace("\n", " ") # Replace newlines with spaces
|
|
||||||
with open(log_file, "a") as f:
|
|
||||||
f.write(f"{timestamp} [{message_type}] {content}\n")
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
def save_tool_call_decorator(obj, func_name):
|
|
||||||
func = getattr(obj, func_name)
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
func(*args, **kwargs)
|
|
||||||
timestamp, tool_name, args = obj.tool_calls[-1]
|
|
||||||
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
|
|
||||||
with open(log_file, "a") as f:
|
|
||||||
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
def save_report_section_decorator(obj, func_name):
|
|
||||||
func = getattr(obj, func_name)
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(section_name, content):
|
|
||||||
func(section_name, content)
|
|
||||||
if section_name in obj.report_sections and obj.report_sections[section_name] is not None:
|
|
||||||
content = obj.report_sections[section_name]
|
|
||||||
if content:
|
|
||||||
file_name = f"{section_name}.md"
|
|
||||||
with open(report_dir / file_name, "w") as f:
|
|
||||||
f.write(content)
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
|
|
||||||
message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
|
|
||||||
message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
|
|
||||||
|
|
||||||
# Now start the display layout
|
# Now start the display layout
|
||||||
layout = create_layout()
|
layout = create_layout()
|
||||||
|
|
||||||
|
|||||||
@@ -150,7 +150,6 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||||||
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
|
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
|
||||||
],
|
],
|
||||||
"ollama": [
|
"ollama": [
|
||||||
("llama3.1 local", "llama3.1"),
|
|
||||||
("llama3.2 local", "llama3.2"),
|
("llama3.2 local", "llama3.2"),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -212,7 +211,6 @@ def select_deep_thinking_agent(provider) -> str:
|
|||||||
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
|
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
|
||||||
],
|
],
|
||||||
"ollama": [
|
"ollama": [
|
||||||
("llama3.1 local", "llama3.1"),
|
|
||||||
("qwen3", "qwen3"),
|
("qwen3", "qwen3"),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ class Toolkit:
|
|||||||
def get_YFin_data(
|
def get_YFin_data(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
end_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
|
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
|
||||||
@@ -145,7 +145,7 @@ class Toolkit:
|
|||||||
def get_YFin_data_online(
|
def get_YFin_data_online(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
end_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
|
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ class FinancialSituationMemory:
|
|||||||
self.embedding = "nomic-embed-text"
|
self.embedding = "nomic-embed-text"
|
||||||
else:
|
else:
|
||||||
self.embedding = "text-embedding-3-small"
|
self.embedding = "text-embedding-3-small"
|
||||||
self.client = OpenAI(base_url=config["backend_url"])
|
self.client = OpenAI()
|
||||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||||
|
|
||||||
|
|||||||
@@ -628,7 +628,7 @@ def get_YFin_data_window(
|
|||||||
def get_YFin_data_online(
|
def get_YFin_data_online(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
end_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
):
|
):
|
||||||
|
|
||||||
datetime.strptime(start_date, "%Y-%m-%d")
|
datetime.strptime(start_date, "%Y-%m-%d")
|
||||||
@@ -670,7 +670,7 @@ def get_YFin_data_online(
|
|||||||
def get_YFin_data(
|
def get_YFin_data(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
end_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
) -> str:
|
) -> str:
|
||||||
# read in data
|
# read in data
|
||||||
data = pd.read_csv(
|
data = pd.read_csv(
|
||||||
@@ -704,10 +704,10 @@ def get_YFin_data(
|
|||||||
|
|
||||||
def get_stock_news_openai(ticker, curr_date):
|
def get_stock_news_openai(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
client = OpenAI(base_url=config["backend_url"])
|
client = OpenAI()
|
||||||
|
|
||||||
response = client.responses.create(
|
response = client.responses.create(
|
||||||
model=config["quick_think_llm"],
|
model="gpt-4.1-mini",
|
||||||
input=[
|
input=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
@@ -739,10 +739,10 @@ def get_stock_news_openai(ticker, curr_date):
|
|||||||
|
|
||||||
def get_global_news_openai(curr_date):
|
def get_global_news_openai(curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
client = OpenAI(base_url=config["backend_url"])
|
client = OpenAI()
|
||||||
|
|
||||||
response = client.responses.create(
|
response = client.responses.create(
|
||||||
model=config["quick_think_llm"],
|
model="gpt-4.1-mini",
|
||||||
input=[
|
input=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
@@ -774,10 +774,10 @@ def get_global_news_openai(curr_date):
|
|||||||
|
|
||||||
def get_fundamentals_openai(ticker, curr_date):
|
def get_fundamentals_openai(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
client = OpenAI(base_url=config["backend_url"])
|
client = OpenAI()
|
||||||
|
|
||||||
response = client.responses.create(
|
response = client.responses.create(
|
||||||
model=config["quick_think_llm"],
|
model="gpt-4.1-mini",
|
||||||
input=[
|
input=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import os
|
|||||||
|
|
||||||
DEFAULT_CONFIG = {
|
DEFAULT_CONFIG = {
|
||||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
|
||||||
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
||||||
"data_cache_dir": os.path.join(
|
"data_cache_dir": os.path.join(
|
||||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
|
|||||||
Reference in New Issue
Block a user