philipobiorah committed on
Commit
211e5c4
·
verified ·
1 Parent(s): b010b3f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +12 -5
main.py CHANGED
@@ -10,27 +10,34 @@ import matplotlib.pyplot as plt
10
  import base64
11
  from io import BytesIO
12
 
13
- # Ensure Matplotlib and Transformers use writable cache directories
14
- os.environ["MPLCONFIGDIR"] = "/home/user/.cache/matplotlib"
15
- os.makedirs("/home/user/.cache/matplotlib", exist_ok=True)
 
 
 
 
16
 
17
  app = Flask(__name__)
18
 
19
  # Load Model from Local Directory
20
  MODEL_PATH = "bert_imdb_model.bin"
21
- TOKENIZER_PATH = "bert-base-uncased" # Using default tokenizer from transformers
22
 
23
  if os.path.exists(MODEL_PATH):
24
  print("Loading model from local file...")
25
  model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
26
  model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
27
  else:
28
- print(f"Error: Model file {MODEL_PATH} not found. Make sure it's in the same directory.")
29
  exit(1)
30
 
31
  model.eval()
32
  tokenizer = BertTokenizer.from_pretrained(TOKENIZER_PATH)
33
 
 
 
 
34
  def predict_sentiment(text):
35
  tokens = tokenizer.encode(text, add_special_tokens=True)
36
  chunks = [tokens[i:i + 512] for i in range(0, len(tokens), 512)]
 
10
  import base64
11
  from io import BytesIO
12
 
13
+ # Set writable cache directories within /tmp
14
+ os.environ["HF_HOME"] = "/tmp/huggingface_cache" # Replaces TRANSFORMERS_CACHE
15
+ os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
16
+
17
+ # Create directories if they don't exist
18
+ os.makedirs(os.environ["HF_HOME"], exist_ok=True)
19
+ os.makedirs(os.environ["MPLCONFIGDIR"], exist_ok=True)
20
 
21
  app = Flask(__name__)
22
 
23
  # Load Model from Local Directory
24
  MODEL_PATH = "bert_imdb_model.bin"
25
+ TOKENIZER_PATH = "bert-base-uncased"
26
 
27
  if os.path.exists(MODEL_PATH):
28
  print("Loading model from local file...")
29
  model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
30
  model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
31
  else:
32
+ print(f"Error: Model file {MODEL_PATH} not found.")
33
  exit(1)
34
 
35
  model.eval()
36
  tokenizer = BertTokenizer.from_pretrained(TOKENIZER_PATH)
37
 
38
+ # ... rest of your code (keep the rest unchanged) ...
39
+
40
+
41
  def predict_sentiment(text):
42
  tokens = tokenizer.encode(text, add_special_tokens=True)
43
  chunks = [tokens[i:i + 512] for i in range(0, len(tokens), 512)]