adding files for bedrock-converse
This commit is contained in:
@@ -42,7 +42,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -143,7 +143,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -172,23 +172,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from anthropic import Anthropic\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"import boto3\n",
|
||||
"\n",
|
||||
"load_dotenv()\n",
|
||||
"client = Anthropic()\n",
|
||||
"def get_model_response(prompt):\n",
|
||||
"\n",
|
||||
"def get_model_response(prompt, model_name):\n",
|
||||
" response = client.messages.create(\n",
|
||||
" model=model_name,\n",
|
||||
" max_tokens=200,\n",
|
||||
" messages=[{'role': 'user', 'content': prompt}]\n",
|
||||
" bedrock_client = boto3.client(service_name='bedrock-runtime', region_name=\"us-west-2\")\n",
|
||||
" model_id = \"anthropic.claude-3-haiku-20240307-v1:0\"\n",
|
||||
"\n",
|
||||
" # Send the message.\n",
|
||||
" response = bedrock_client.converse(\n",
|
||||
" modelId=model_id,\n",
|
||||
" messages=[{\"role\": \"user\", \"content\": [{\"text\":prompt}]}]\n",
|
||||
" )\n",
|
||||
" return response.content[0].text\n",
|
||||
"\n",
|
||||
" return response[\"output\"][\"message\"][\"content\"][0][\"text\"]\n",
|
||||
"\n",
|
||||
"def calculate_accuracy(eval_data, model_responses):\n",
|
||||
" correct_predictions = 0\n",
|
||||
@@ -203,9 +204,8 @@
|
||||
" \n",
|
||||
" return correct_predictions / total_predictions\n",
|
||||
"\n",
|
||||
"def evaluate_prompt(prompt_func, eval_data, model_name):\n",
|
||||
" print(f\"Evaluating with model: {model_name}\")\n",
|
||||
" model_responses = [get_model_response(prompt_func(item['complaint']), model_name) for item in eval_data]\n",
|
||||
"def evaluate_prompt(prompt_func, eval_data):\n",
|
||||
" model_responses = [get_model_response(prompt_func(item['complaint'])) for item in eval_data]\n",
|
||||
" accuracy = calculate_accuracy(eval_data, model_responses)\n",
|
||||
" \n",
|
||||
" print(f\"Accuracy: {accuracy:.2%}\")\n",
|
||||
@@ -241,14 +241,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluating with model: claude-3-haiku-20240307\n",
|
||||
"Accuracy: 85.00%\n",
|
||||
"\n",
|
||||
"Complaint: The app crashes every time I try to upload a photo\n",
|
||||
@@ -293,8 +292,7 @@
|
||||
"\n",
|
||||
"Complaint: The app is crashing and my phone is overheating\n",
|
||||
"Golden Answer: ['Software Bug', 'Hardware Malfunction']\n",
|
||||
"Model Response: Hardware Malfunction\n",
|
||||
"Software Bug\n",
|
||||
"Model Response: Software Bug, Hardware Malfunction\n",
|
||||
"\n",
|
||||
"Complaint: I can't remember my password!\n",
|
||||
"Golden Answer: ['User Error']\n",
|
||||
@@ -306,11 +304,11 @@
|
||||
"\n",
|
||||
"Complaint: I think I installed something incorrectly, now my computer won't start at all\n",
|
||||
"Golden Answer: ['User Error', 'Hardware Malfunction']\n",
|
||||
"Model Response: User Error, Hardware Malfunction\n",
|
||||
"Model Response: User Error\n",
|
||||
"\n",
|
||||
"Complaint: Your service is down, and I urgently need a feature to batch process files\n",
|
||||
"Golden Answer: ['Service Outage', 'Feature Request']\n",
|
||||
"Model Response: Feature Request, Service Outage\n",
|
||||
"Model Response: Service Outage, Feature Request\n",
|
||||
"\n",
|
||||
"Complaint: The graphics card is making weird noises\n",
|
||||
"Golden Answer: ['Hardware Malfunction']\n",
|
||||
@@ -339,13 +337,13 @@
|
||||
"0.85"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"evaluate_prompt(basic_prompt, eval_data, model_name=\"claude-3-haiku-20240307\")"
|
||||
"evaluate_prompt(basic_prompt, eval_data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -362,7 +360,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -425,14 +423,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 80,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluating with model: claude-3-haiku-20240307\n",
|
||||
"Accuracy: 100.00%\n",
|
||||
"\n",
|
||||
"Complaint: The app crashes every time I try to upload a photo\n",
|
||||
@@ -522,13 +519,13 @@
|
||||
"1.0"
|
||||
]
|
||||
},
|
||||
"execution_count": 80,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"evaluate_prompt(improved_prompt, eval_data, model_name=\"claude-3-haiku-20240307\")"
|
||||
"evaluate_prompt(improved_prompt, eval_data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user