Local Ollama (#53)

- Fix typo 'Start' 'End'
- Add llama3.1 selection
- Use 'quick_think_llm' model instead of hard-coding GPT
This commit is contained in:
Max Wong
2025-06-26 00:27:01 -04:00
committed by GitHub
parent 26c5ba5a78
commit 43aa9c5d09
4 changed files with 13 additions and 11 deletions

View File

@@ -150,6 +150,7 @@ def select_shallow_thinking_agent(provider) -> str:
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"), ("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
], ],
"ollama": [ "ollama": [
("llama3.1 local", "llama3.1"),
("llama3.2 local", "llama3.2"), ("llama3.2 local", "llama3.2"),
] ]
} }
@@ -211,6 +212,7 @@ def select_deep_thinking_agent(provider) -> str:
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"), ("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
], ],
"ollama": [ "ollama": [
("llama3.1 local", "llama3.1"),
("qwen3", "qwen3"), ("qwen3", "qwen3"),
] ]
} }

View File

@@ -124,7 +124,7 @@ class Toolkit:
def get_YFin_data( def get_YFin_data(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str: ) -> str:
""" """
Retrieve the stock price data for a given ticker symbol from Yahoo Finance. Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
@@ -145,7 +145,7 @@ class Toolkit:
def get_YFin_data_online( def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str: ) -> str:
""" """
Retrieve the stock price data for a given ticker symbol from Yahoo Finance. Retrieve the stock price data for a given ticker symbol from Yahoo Finance.

View File

@@ -9,7 +9,7 @@ class FinancialSituationMemory:
self.embedding = "nomic-embed-text" self.embedding = "nomic-embed-text"
else: else:
self.embedding = "text-embedding-3-small" self.embedding = "text-embedding-3-small"
self.client = OpenAI() self.client = OpenAI(base_url=config["backend_url"])
self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name) self.situation_collection = self.chroma_client.create_collection(name=name)

View File

@@ -628,7 +628,7 @@ def get_YFin_data_window(
def get_YFin_data_online( def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"],
): ):
datetime.strptime(start_date, "%Y-%m-%d") datetime.strptime(start_date, "%Y-%m-%d")
@@ -670,7 +670,7 @@ def get_YFin_data_online(
def get_YFin_data( def get_YFin_data(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str: ) -> str:
# read in data # read in data
data = pd.read_csv( data = pd.read_csv(
@@ -704,10 +704,10 @@ def get_YFin_data(
def get_stock_news_openai(ticker, curr_date): def get_stock_news_openai(ticker, curr_date):
config = get_config() config = get_config()
client = OpenAI() client = OpenAI(base_url=config["backend_url"])
response = client.responses.create( response = client.responses.create(
model="gpt-4.1-mini", model=config["quick_think_llm"],
input=[ input=[
{ {
"role": "system", "role": "system",
@@ -739,10 +739,10 @@ def get_stock_news_openai(ticker, curr_date):
def get_global_news_openai(curr_date): def get_global_news_openai(curr_date):
config = get_config() config = get_config()
client = OpenAI() client = OpenAI(base_url=config["backend_url"])
response = client.responses.create( response = client.responses.create(
model="gpt-4.1-mini", model=config["quick_think_llm"],
input=[ input=[
{ {
"role": "system", "role": "system",
@@ -774,10 +774,10 @@ def get_global_news_openai(curr_date):
def get_fundamentals_openai(ticker, curr_date): def get_fundamentals_openai(ticker, curr_date):
config = get_config() config = get_config()
client = OpenAI() client = OpenAI(base_url=config["backend_url"])
response = client.responses.create( response = client.responses.create(
model="gpt-4.1-mini", model=config["quick_think_llm"],
input=[ input=[
{ {
"role": "system", "role": "system",