"""Streamlit demo: multi-label intent classification with ``dejanseo/Intent-XS``."""

import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "dejanseo/Intent-XS"


@st.cache_resource
def load_model(name=model_name):
    """Load the tokenizer and sequence-classification model for *name*.

    Streamlit re-executes this script from the top on every widget
    interaction. ``st.cache_resource`` makes the expensive ``from_pretrained``
    calls run once per server process instead of on every rerun.

    Args:
        name: Hugging Face model identifier or local path.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is put in eval mode for
        inference.
    """
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModelForSequenceClassification.from_pretrained(name)
    mdl.eval()  # inference only: disable training-time behaviour such as dropout
    return tok, mdl


# Module-level names kept so the rest of the script can use them directly.
tokenizer, model = load_model(model_name)


# Display names for the classifier's outputs, keyed from 1 (not 0): the UI
# looks up output position i as label_map[i + 1].
# NOTE(review): assumes this order matches the model's own output order —
# confirm against model.config.id2label.
label_map = dict(
    enumerate(
        (
            "Commercial",
            "Non-Commercial",
            "Branded",
            "Non-Branded",
            "Informational",
            "Navigational",
            "Transactional",
            "Commercial Investigation",
            "Local",
            "Entertainment",
        ),
        start=1,
    )
)


def get_predictions(text: str, threshold: float = 0.5, max_length: int = 512):
    """Run the multi-label intent classifier on a single piece of text.

    Each logit goes through an independent sigmoid (multi-label, not softmax).
    A label counts as predicted when its probability is strictly greater than
    ``threshold``.

    Args:
        text: The query string to classify.
        threshold: Probability a label must exceed to be predicted.
        max_length: Inputs longer than this many tokens are truncated.

    Returns:
        ``(probabilities, predictions)`` as 1-D NumPy arrays with one entry
        per model output: sigmoid probabilities and 0/1 integer flags.
    """
    # A single sequence needs no padding. Padding it to max_length only adds
    # masked pad tokens that the model still has to run over on every query.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Select the one batch row explicitly. squeeze() would also collapse the
    # label axis to a 0-d tensor if the model had a single output.
    probabilities = torch.sigmoid(logits[0]).cpu()
    predictions = (probabilities > threshold).int()
    return probabilities.numpy(), predictions.numpy()


st.title('Multi-label Classification with Intent-XS')
query = st.text_input("Enter your query:")

# Streamlit re-runs this script on each interaction; st.button is True only
# on the rerun triggered by clicking it.
if st.button('Submit'):
    # Whitespace-only input is treated as empty instead of being classified.
    if query.strip():
        probabilities, predictions = get_predictions(query)
        # Keep only the labels flagged as predicted. label_map is keyed from 1,
        # so output position i maps to key i + 1. Use a generic name if the
        # model has more outputs than label_map covers, rather than raising
        # KeyError.
        result = {
            label_map.get(i + 1, f"Label {i + 1}"): f"Probability: {prob:.2f}"
            for i, (prob, flag) in enumerate(zip(probabilities, predictions))
            if flag == 1
        }
        if result:
            st.write("Predicted Categories:")
            for label, prob in result.items():
                st.write(f"{label}: {prob}")
        else:
            st.write("No relevant categories predicted.")
    else:
        st.write("Please enter a query to get predictions.")