MrKustic commited on
Commit
dddf864
·
verified ·
1 Parent(s): a504c78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py CHANGED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+ MODEL_NAME = "t-bank-ai/RuDialoGPT-small"
6
+
7
+ print("Загружаем модель и токенизатор...")
8
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
10
+
11
+ # Если в Spaces доступен GPU, переводим модель на него
12
+ device = 0 if torch.cuda.is_available() else -1
13
+ if device == 0:
14
+ model = model.to("cuda")
15
+
16
+ model.eval()
17
+
18
+ def chat(user_input):
19
+ # Дополнительно можно добавить обозначение конца строки для корректного завершения генерации
20
+ input_with_eos = user_input + tokenizer.eos_token
21
+
22
+ # Токенизируем входной текст
23
+ inputs = tokenizer.encode(input_with_eos, return_tensors="pt")
24
+ if device >= 0:
25
+ inputs = inputs.to("cuda")
26
+
27
+ # Генерация ответа
28
+ outputs = model.generate(
29
+ inputs,
30
+ max_length=200, # можно изменить длину генерируемого текста
31
+ pad_token_id=tokenizer.eos_token_id,
32
+ do_sample=True,
33
+ top_p=0.9,
34
+ temperature=0.7
35
+ )
36
+ # Декодируем сгенерированный текст
37
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
38
+
39
+ # Если модель возвращает и исходный текст, можно убрать его:
40
+ if generated_text.startswith(user_input):
41
+ generated_text = generated_text[len(user_input):].strip()
42
+ return generated_text
43
+
44
+ # Создаем интерфейс Gradio
45
+ iface = gr.Interface(
46
+ fn=chat,
47
+ inputs=gr.Textbox(lines=2, placeholder="Например: Привет, как дела?", label="Введите сообщение"),
48
+ outputs=gr.Textbox(label="Ответ модели"),
49
+ title="RuDialoGPT-small Chat",
50
+ description="Диалоговый чат на базе модели t-bank-ai/RuDialoGPT-small от Hugging Face"
51
+ )
52
+
53
+ if __name__ == "__main__":
54
+ iface.launch()