diff --git a/README.md b/README.md index 8ab6e0c..7c41f2a 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,31 @@ The file of our checkpoint generated after we fine-tuned it is upload to [Google You can see the prompt we tried from the file `prompt templates`. Up to now, we are using the file `prompt templates/license_master_v4.json `. +# Testing Fine-tuned Modelsc +To test the fine-tuned model with a test dataset of question-answer pairs, where the question is a segment of license content and the answer is a human-annotated risk rating (high, medium, low risk), please follow these steps: + + +1. install necessary libraries: +``` +pip install transformers torch scikit-learn +``` +2. prepate the test data as follows: +``` +[ + { + "instruction": ".......", + "input": ".....", + "output": "......" + }, + ... +] +``` +3. run the test script +``` +python test.py +``` +This script will load the fine-tuned model and test dataset, generate answers for each question, calculate the semantic similarity between the predicted and standard answers using BERT, and finally compute the accuracy based on a predefined similarity threshold which can be modified by users. + # Computing resources NVIDIA GeForce RTX 3090 @@ -101,6 +126,8 @@ Despite the smaller scale of LicenseGPT, its tailored fine-tuning process allowe The traditional way means software IP lawyers should manually invest a substantial amount of time consulting resources and legal literature, resulting in context-specific outcomes that can take anywhere from 10 minutes to an hour, with an average of 30 minutes per case. In stark contrast, the proposed LicenseGPT's inference time spans from a swift 1 millisecond to 70 seconds at the upper end, with an average of just 10 seconds. With the help of LicenseGPT, a substantial reduction translates to an average review time for software IP lawyers of 10 minutes, representing a notable efficiency gain and freeing up valuable time for legal professionals. + + # Case Study Details in CaseStudy.md diff --git a/test.py b/test.py new file mode 100644 index 0000000..c9ab08f --- /dev/null +++ b/test.py @@ -0,0 +1,56 @@ +import json +from transformers import AutoModelForSequenceClassification, AutoTokenizer, BertModel +from sklearn.metrics.pairwise import cosine_similarity +import torch + +# 加载微调后的模型 +model_path = 'path/to/your/fine-tuned-model' +model = AutoModelForSequenceClassification.from_pretrained(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path) + +# 加载BERT模型用于语义相似度比较 +bert_model_name = 'bert-base-uncased' +bert_model = BertModel.from_pretrained(bert_model_name) +bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name) + +# 加载测试数据集 +test_data_path = 'path/to/your/test-data.json' + +with open(test_data_path, 'r', encoding='utf-8') as f: + test_data = json.load(f) + +# 生成预测结果 +predicted_answers = [] + +for item in test_data: + question = item['input'] + inputs = tokenizer(question, return_tensors="pt", truncation=True, padding=True, max_length=512) + outputs = model(**inputs) + predicted_answer = tokenizer.decode(torch.argmax(outputs.logits, dim=-1), skip_special_tokens=True) + predicted_answers.append(predicted_answer) + +# 计算语义相似度 +def get_bert_embedding(text, tokenizer, model): + inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512) + with torch.no_grad(): + outputs = model(**inputs) + return outputs.last_hidden_state.mean(dim=1) + +similarities = [] + +for idx, item in enumerate(test_data): + standard_answer = item['output'] + predicted_answer = predicted_answers[idx] + + standard_embedding = get_bert_embedding(standard_answer, bert_tokenizer, bert_model) + predicted_embedding = get_bert_embedding(predicted_answer, bert_tokenizer, bert_model) + + similarity = cosine_similarity(standard_embedding.numpy(), predicted_embedding.numpy())[0][0] + similarities.append(similarity) + +# 计算准确率 +threshold = 0.8 +correct_predictions = [1 if sim >= threshold else 0 for sim in similarities] +accuracy = sum(correct_predictions) / len(correct_predictions) + +print(f"Accuracy: {accuracy * 100:.2f}%")