TymaaHammouda's picture
Update app.py
b704903 verified
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
import os
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModel
import json
# Startup banner; presumably a manual marker for which revision of the app is
# running (bump by hand when redeploying) — confirm with the maintainer.
print("Version ---- 1")

# The ASGI application; the route handlers further down register on it via
# the @app.post(...) decorators.
app = FastAPI()
from huggingface_hub import snapshot_download, hf_hub_download
import os
import shutil
# Root directory of the on-disk cache that sinatools resources live in;
# created eagerly so the download steps below can write into it.
BASE_DIR = os.path.expanduser("~/.sinatools")
os.makedirs(BASE_DIR, exist_ok=True)

# Paths expected by sinatools: the leaf names must match what the library
# looks up, so do not rename them. NOTE: "Wj27012000.tar" is used as a
# *directory* name here (see the NER download step), despite the suffix.
RELATION_MODEL_DIR, NER_DIR = (
    os.path.join(BASE_DIR, leaf) for leaf in ("relation_model", "Wj27012000.tar")
)
# -------------------------
# 1. Download relation model
# -------------------------
# Fetch the relation-extraction model only when its directory is missing or
# empty; any non-empty directory is treated as an already-completed download.
if not (os.path.exists(RELATION_MODEL_DIR) and os.listdir(RELATION_MODEL_DIR)):
    snapshot_download(
        repo_id="aaljabari/arabic-relation-extraction-model",
        local_dir=RELATION_MODEL_DIR,
        # Ask for real files in local_dir rather than symlinks into the HF
        # cache. NOTE(review): newer huggingface_hub releases deprecate this
        # argument — confirm against the pinned version.
        local_dir_use_symlinks=False,
    )
# -------------------------
# 2. Download NER resources
# -------------------------
# sinatools expects tag_vocab.pkl inside NER_DIR. Gate on that *file* rather
# than on the directory: the directory is created before the download runs,
# so gating on it meant an interrupted download (or a snapshot lacking the
# vocab) left an empty NER_DIR that made every later start skip this step.
dst_vocab = os.path.join(NER_DIR, "tag_vocab.pkl")
if not os.path.exists(dst_vocab):
    os.makedirs(NER_DIR, exist_ok=True)
    nested_repo_path = snapshot_download(
        repo_id="SinaLab/Nested"
    )
    # Copy tag_vocab.pkl to expected location
    src_vocab = os.path.join(nested_repo_path, "Nested", "utils", "tag_vocab.pkl")
    if os.path.exists(src_vocab):
        # Copy under a temporary name and atomically rename, so a copy cut
        # short never leaves a truncated pickle that passes the guard above.
        tmp_vocab = dst_vocab + ".part"
        shutil.copy(src_vocab, tmp_vocab)
        os.replace(tmp_vocab, dst_vocab)
    else:
        # Best-effort as before (startup continues), but no longer silent.
        print("WARNING: tag_vocab.pkl not found in SinaLab/Nested snapshot:", src_vocab)
# Optional debug: show what is on disk under the sinatools cache at startup.
for _label, _path in (("sinatools dir:", BASE_DIR), ("NER dir:", NER_DIR)):
    print(_label, os.listdir(_path))
from sinatools.relations.relation_extractor import relation_extraction
from sinatools.relations.event_relation_extractor import event_argument_relation_extraction
# Request body accepted by both POST /predict_relation and POST /predict_event.
# Documented with comments rather than a docstring on purpose: pydantic copies
# a class docstring into the generated JSON/OpenAPI schema as "description".
class RelationRequest(BaseModel):
    # Raw input text passed unchanged to the sinatools extractors; presumably
    # Arabic given the models loaded above — confirm.
    text: str
@app.post("/predict_relation")
def predict_relation(request: RelationRequest):
    """Run sinatools relation extraction on ``request.text``.

    The extractor's output is returned under ``"resp"`` in the service's
    envelope ``{"resp": ..., "statusText": "OK", "statusCode": 0}`` with
    HTTP status 200.
    """
    extracted = relation_extraction(request.text)
    payload = dict(resp=extracted, statusText="OK", statusCode=0)
    return JSONResponse(payload, status_code=200, media_type="application/json")
@app.post("/predict_event")
def predict_event(request: RelationRequest):
    """Run sinatools event-argument relation extraction on ``request.text``.

    Same response envelope as ``/predict_relation``: the extractor output
    under ``"resp"``, ``statusText="OK"``, ``statusCode=0``, HTTP 200.
    """
    return JSONResponse(
        content={
            "resp": event_argument_relation_extraction(request.text),
            "statusText": "OK",
            "statusCode": 0,
        },
        status_code=200,
        media_type="application/json",
    )