64 lines
2.5 KiB
Python
64 lines
2.5 KiB
Python
import re
|
||
|
||
from langchain_core.output_parsers import StrOutputParser
|
||
from langchain_core.prompts import ChatPromptTemplate
|
||
from langchain_openai import ChatOpenAI
|
||
from loguru import logger
|
||
import traceback
|
||
from config.config import load_config
|
||
|
||
class_type_config =[
|
||
"本期开展了小袋鼠整合主题课程:(语言、社会、科学、健康、艺术)、生活数学;特色课程(英语、体能、美工、篮球)。",
|
||
"本学期开展了柏克莱主题课程(语言、社会、科学、艺术、健康);英语及特色课程(体能、舞蹈、美工、魔力猴、足球、国学)。",
|
||
"本学期开展了双木桥主题课程(图说汉字、妙趣汉音、情智阅读、麦斯思维、专注力训练);英语及特色课程(体能、舞蹈、美工、魔力猴、足球、国学)。"
|
||
]
|
||
|
||
def generate_comment(name, age_group, traits,sex):
|
||
"""
|
||
生成评语
|
||
:param name: 学生姓名
|
||
:param age_group: 所在班级
|
||
:param traits: 表现特征
|
||
:param sex: 性别
|
||
:return: 评语
|
||
"""
|
||
# 1. 加载配置文件
|
||
try:
|
||
config = load_config("config.toml")
|
||
except Exception as e:
|
||
logger.error(f"配置文件获取失败: {str(e)}")
|
||
# 打印详细报错位置,方便调试
|
||
logger.error(traceback.format_exc())
|
||
return "配置文件加载失败,请检查文件路径和内容。"
|
||
ai_config = config["ai"]
|
||
llm = ChatOpenAI(
|
||
base_url=ai_config["api_url"],
|
||
api_key=ai_config["api_key"],
|
||
model=ai_config["model"],
|
||
temperature=0.7,
|
||
)
|
||
# 2. 构建 Prompt Template
|
||
prompt = ChatPromptTemplate.from_messages([
|
||
("system", ai_config["prompt"]),
|
||
("human", "学生姓名:{name}\n所在班级:{age_group}\n性别:{sex}\n表现特征:{traits}\n\n请开始撰写评语:")
|
||
])
|
||
|
||
# 3. 组装链 (Prompt -> Model -> OutputParser)
|
||
chain = prompt | llm | StrOutputParser()
|
||
|
||
# 4. 执行
|
||
try:
|
||
comment = chain.invoke({
|
||
"name": name,
|
||
"age_group": age_group,
|
||
"traits": traits,
|
||
"sex": sex,
|
||
"class_type": class_type_config[(config.get("class_type", 0))],
|
||
})
|
||
cleaned_text = re.sub(r'\s+', '', comment)
|
||
logger.success(f"学生:{name} =>生成评语成功: {cleaned_text}")
|
||
return cleaned_text
|
||
except Exception as e:
|
||
logger.error(f"生成评语失败: {e}")
|
||
return "生成失败,请检查网络或Key。"
|