| |
|
| | import time |
| | import os |
| | import json |
| | from werkzeug.utils import secure_filename |
| | import re |
| | import ast |
| | import sqlite3 |
| | import random |
| |
|
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| |
|
| | from llmware.models import ModelCatalog |
| | from llmware.prompts import Prompt |
| |
|
def model_test_run_general():
    """Evaluate a fine-tuned SQL-generation model against a local JSONL test set.

    Loads ``llmware/slim-sql-1b-v0`` through HuggingFace transformers, wraps it
    in an llmware generative model, then for every question/context pair in
    ``sql_test_100_simple_s.jsonl`` runs inference and compares the response to
    the gold answer (quote-stripped, case-insensitive). Progress, per-sample
    results, and a running accuracy are printed as it goes.

    Returns:
        list[dict]: one record per evaluated sample with keys
        "number", "llm_response", "gold_answer", "prompt", and "usage".
    """

    t0 = time.time()

    model_name = "llmware/slim-sql-1b-v0"

    print("update: model_name - ", model_name)

    custom_hf_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)

    # wrap the raw HF model/tokenizer so llmware handles prompt formatting
    model = ModelCatalog().load_hf_generative_model(custom_hf_model, hf_tokenizer,
                                                    instruction_following=False,
                                                    prompt_wrapper="human_bot")
    model.temperature = 0.3

    print("\nupdate: Starting Generative Instruct Custom Fine-tuned Test")

    t1 = time.time()
    print("update: time loading model - ", t1 - t0)

    fp = ""
    fn = "sql_test_100_simple_s.jsonl"

    prompt_list = []

    # fix: close the test file deterministically (original leaked the handle)
    with open(os.path.join(fp, fn), "r", encoding="utf-8") as opened_file:
        for line in opened_file:
            rows = json.loads(line)
            prompt_list.append({"question": rows["question"],
                                "answer": rows["answer"],
                                "context": rows["context"]})

    random.shuffle(prompt_list)

    total_response_output = []
    perfect_match = 0

    for i, entries in enumerate(prompt_list):

        prompt = entries["question"]

        # normalize the context: strip newlines, collapse runs of whitespace,
        # drop double quotes (fix: raw strings — "\s" is an invalid escape in
        # a plain string literal)
        context = re.sub(r"[\n\r]", "", entries["context"])
        context = re.sub(r"\s+", " ", context)
        context = re.sub(r"\"", "", context)

        answer = entries.get("answer", "")

        output = model.inference(prompt, add_context=context, add_prompt_engineering=True)

        print("\nupdate: model question - ", prompt)

        # strip quote characters before comparing so quoting-style differences
        # do not count as a miss
        llm_response = re.sub(r"['\"]", "", output["llm_response"])
        answer = re.sub(r"['\"]", "", answer)

        print("update: model response - ", i, llm_response)
        print("update: model gold answer - ", answer)

        if llm_response.strip().lower() == answer.strip().lower():
            perfect_match += 1
            print("update: 100% MATCH")

        # running accuracy over the samples processed so far
        print("update: perfect match accuracy - ", perfect_match / (i + 1))

        total_response_output.append({"number": i,
                                      "llm_response": output["llm_response"],
                                      "gold_answer": answer,
                                      "prompt": prompt,
                                      "usage": output["usage"]})

    t2 = time.time()
    print("update: total processing time: ", t2 - t1)

    return total_response_output
| |
|
# fix: guard the entry point — the bare call ran a full model download and
# evaluation as a side effect of merely importing this module
if __name__ == "__main__":
    output = model_test_run_general()
| |
|