"""satellite_app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
"""
| |
|
| | import gradio as gr |
| | from safetensors.torch import load_model |
| | from timm import create_model |
| | from huggingface_hub import hf_hub_download |
| | from datasets import load_dataset |
| | import torch |
| | import torchvision.transforms as T |
| | import cv2 |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | from PIL import Image |
| | import os |
| |
|
| | from langchain_community.document_loaders import TextLoader |
| | from langchain_community.vectorstores import FAISS |
| | from langchain_community.embeddings import HuggingFaceEmbeddings |
| | from langchain.text_splitter import CharacterTextSplitter |
| | from langchain_core.output_parsers import StrOutputParser |
| | from langchain_core.runnables import RunnablePassthrough |
| | from langchain_fireworks import ChatFireworks |
| | from langchain_core.prompts import ChatPromptTemplate |
| | from transformers import AutoModelForImageClassification, AutoImageProcessor |
| |
|
| |
|
# --- Model setup -----------------------------------------------------------
# Path to the fine-tuned checkpoint (safetensors format).
safe_tensors = "model.safetensors"

# timm identifier for the Swin-S3 base architecture (224x224 inputs).
model_name = 'swin_s3_base_224'

# 17 output classes — the multi-label landscape categories decoded by
# one_hot_decoding().
model = create_model(
    model_name,
    num_classes=17
)

# Load the fine-tuned weights into the freshly built architecture.
load_model(model, safe_tensors)

# This app only runs inference: switch the model to eval mode so dropout
# is disabled and normalization layers use their trained statistics.
# (timm models are created in training mode by default.)
model.eval()
| |
|
def one_hot_decoding(labels):
    """Convert a multi-hot indicator vector into its class-name strings.

    Args:
        labels: iterable of per-class indicator values (ints or floats),
            where a value equal to 1 marks a detected class.  Order must
            match ``class_names`` below (the model's 17 output classes).

    Returns:
        list[str]: the names of the classes whose indicator equals 1,
        in class-index order.  Empty list when nothing is detected.
    """
    class_names = ['conventional_mine', 'habitation', 'primary', 'water', 'agriculture', 'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming', 'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn', 'clear', 'haze']
    # Pair each class name with its indicator and keep the active ones;
    # replaces the original two-pass index-collect + lookup loops.
    return [name for name, flag in zip(class_names, labels) if flag == 1]
| |
|
def ragChain():
    """Build and return the retrieval-augmented generation (RAG) chain.

    Pipeline: FAISS similarity retriever (top-5 chunks) -> chat prompt ->
    Fireworks-hosted Mixtral 8x7B -> plain-string output parser.

    Requires:
        - a prebuilt FAISS index in the local ``faiss_index`` directory;
        - the ``FIREWORKS_API_KEY`` environment variable.

    Returns:
        A LangChain runnable; call ``.invoke(question)`` to get the answer.
    """
    # NOTE(review): the original also loaded "document.txt" and split it
    # with CharacterTextSplitter, but the split documents were never used —
    # the vectorstore below is loaded from the prebuilt index — so that
    # dead code (and its file dependency) has been removed.
    vectorstore = FAISS.load_local(
        "faiss_index",
        embeddings=HuggingFaceEmbeddings(),
        # The index is produced by this project itself, so deserializing
        # its pickle payload is accepted deliberately.
        allow_dangerous_deserialization=True,
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

    # Fireworks-hosted Mixtral; the API key comes from the environment.
    api_key = os.getenv("FIREWORKS_API_KEY")
    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key=api_key)

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a knowledgeable landscape deforestation analyst.
                """
            ),
            (
                "human",
                """First mention the detected labels only with short description.
                Provide not more than 4 precautionary measures which are related to the detected labels that can be taken to control deforestation.
                Don't include conversational messages.
                """,
            ),
            # Retrieved context and the user's question are injected here.
            ("human", "{context}, {question}"),
        ]
    )

    # LCEL composition: input mapping | prompt | llm | string parser.
    rag_chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough()
        }
        | prompt
        | llm
        | StrOutputParser()
    )

    return rag_chain
| |
|
def model_output(image):
    """Classify a satellite image and build a natural-language LLM query.

    Args:
        image: HxWx3 numpy array (Gradio's default ``image`` input format);
            values are cast to uint8 before conversion to a PIL image.

    Returns:
        str: a query listing the detected labels for the RAG chain, or a
        fixed "no labels" sentence when nothing crosses the threshold.
    """
    PIL_image = Image.fromarray(image.astype('uint8'), 'RGB')

    # Match the 224x224 input resolution of swin_s3_base_224.
    img_size = (224, 224)
    test_tfms = T.Compose([
        T.Resize(img_size),
        T.ToTensor(),
    ])

    img = test_tfms(PIL_image)

    # Inference only: batch of 1, no gradient tracking.
    with torch.no_grad():
        logits = model(img.unsqueeze(0))

    # Multi-label head: independent per-class sigmoid, thresholded at 0.5.
    predictions = logits.sigmoid() > 0.5
    predictions = predictions.float().numpy().flatten()
    pred_labels = one_hot_decoding(predictions)

    # Guard against the empty-detection case so the downstream prompt is
    # still well-formed (the original produced "...labels are .").
    if not pred_labels:
        return "No labels were detected in the provided satellite image."

    # Comma-separate the labels; space-joining multi-word labels such as
    # "partly_cloudy primary water" was ambiguous in the LLM prompt.
    output_text = ", ".join(pred_labels)

    query = f"Detected labels in the provided satellite image are {output_text}. Give information on the labels."

    return query
| |
|
def generate_response(rag_chain, query):
    """Run *query* through *rag_chain* and return the LLM's answer.

    Args:
        rag_chain: a runnable exposing ``.invoke(str) -> str``
            (as built by ``ragChain``).
        query: the question to answer.

    Returns:
        The chain's generated response text.
    """
    return rag_chain.invoke(str(query))
| | |
def main(image):
    """Gradio entry point: classify the image, then explain it via RAG.

    Args:
        image: numpy array from the Gradio "image" input component.

    Returns:
        str: the LLM-generated analysis of the detected landscape labels.
    """
    return generate_response(ragChain(), model_output(image))
# --- Gradio UI -------------------------------------------------------------
title = "Satellite Image Landscape Analysis for Deforestation"
# Fixed grammar in the user-facing text: "by identify" -> "by identifying".
description = "This bot will take any satellite image and analyze the factors which lead to deforestation by identifying the landscape based on forest areas, roads, habitation, water etc."

# Single image in, plain text out; example images ship with the app.
app = gr.Interface(fn=main, inputs="image", outputs="text", title=title,
                   description=description,
                   examples=[["sample_images/train_142.jpg"], ["sample_images/train_32.jpg"], ["sample_images/random_satellite3.png"], ["sample_images/random_satellite2.png"], ["sample_images/train_75.jpg"], ["sample_images/train_92.jpg"], ["sample_images/random_satellite.png"]])

# share=True exposes a temporary public URL (needed when running in Colab).
app.launch(share=True)
| |
|
| |
|