Trying to get the API working, but it is not working yet
- experimental/clip_api_app.py  +33 -33
- experimental/clip_api_app_client.py  +35 -29
experimental/clip_api_app.py
CHANGED
@@ -1,12 +1,9 @@
-
-import json
-import os
+from typing import List
 import numpy as np
 import torch
-from starlette.requests import Request
-from PIL import Image
 import ray
 from ray import serve
+from PIL import Image
 from clip_retrieval.load_clip import load_clip, get_tokenizer
 # from clip_retrieval.clip_client import ClipClient, Modality
 
@@ -24,11 +21,14 @@ class CLIPTransform:
 
         print ("using device", self.device)
 
-    def text_to_embeddings(self, prompt):
-        text = self.tokenizer([prompt]).to(self.device)
+    @serve.batch(max_batch_size=32)
+    # def text_to_embeddings(self, prompts: List[str]) -> torch.Tensor:
+    def text_to_embeddings(self, prompts: List[str]) -> List[np.ndarray]:
+        text = self.tokenizer(prompts).to(self.device)
         with torch.no_grad():
             prompt_embededdings = self.model.encode_text(text)
             prompt_embededdings /= prompt_embededdings.norm(dim=-1, keepdim=True)
+        prompt_embededdings = prompt_embededdings.cpu().numpy().tolist()
         return(prompt_embededdings)
 
     def image_to_embeddings(self, input_im):
@@ -45,31 +45,31 @@ class CLIPTransform:
         image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
         return(image_embeddings)
 
-    async def __call__(self, http_request: Request) -> str:
-        request = await http_request.json()
-        # print(type(request))
-        # print(str(request))
-        # switch based if we are using text or image
-        embeddings = None
-        if "text" in request:
-            prompt = request["text"]
-            embeddings = self.text_to_embeddings(prompt)
-        elif "image" in request:
-            image_url = request["image_url"]
-            # download image from url
-            import requests
-            from io import BytesIO
-            input_image = Image.open(BytesIO(image_url))
-            input_image = input_image.convert('RGB')
-            input_image = np.array(input_image)
-            embeddings = self.image_to_embeddings(input_image)
-        elif "preprocessed_image" in request:
-            prepro = request["preprocessed_image"]
-            # create torch tensor on the device
-            prepro = torch.tensor(prepro).to(self.device)
-            embeddings = self.preprocessed_image_to_emdeddings(prepro)
-        else:
-            raise Exception("Invalid request")
-        return embeddings.cpu().numpy().tolist()
+    # async def __call__(self, http_request: Request) -> str:
+    #     request = await http_request.json()
+    #     # print(type(request))
+    #     # print(str(request))
+    #     # switch based if we are using text or image
+    #     embeddings = None
+    #     if "text" in request:
+    #         prompt = request["text"]
+    #         embeddings = self.text_to_embeddings(prompt)
+    #     elif "image" in request:
+    #         image_url = request["image_url"]
+    #         # download image from url
+    #         import requests
+    #         from io import BytesIO
+    #         input_image = Image.open(BytesIO(image_url))
+    #         input_image = input_image.convert('RGB')
+    #         input_image = np.array(input_image)
+    #         embeddings = self.image_to_embeddings(input_image)
+    #     elif "preprocessed_image" in request:
+    #         prepro = request["preprocessed_image"]
+    #         # create torch tensor on the device
+    #         prepro = torch.tensor(prepro).to(self.device)
+    #         embeddings = self.preprocessed_image_to_emdeddings(prepro)
+    #     else:
+    #         raise Exception("Invalid request")
+    #     return embeddings.cpu().numpy().tolist()
 
 deployment_graph = CLIPTransform.bind()
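One detail worth checking against the Ray Serve docs: `@serve.batch` is documented for `async def` methods, and it assembles the batch itself, collecting concurrent single-element calls into the list the decorated method receives. Each caller therefore passes one prompt, not a list of prompts. Below is a minimal runnable sketch of that contract; `BatchedEncoder` and its length-based fake embeddings are stand-ins for illustration, not the real CLIP model:

from typing import List

from ray import serve


@serve.deployment
class BatchedEncoder:
    # Stand-in for CLIPTransform, showing only the @serve.batch contract.

    # Serve gathers up to 32 concurrent single-prompt calls into `prompts`.
    # The decorated method must be async and return one result per input.
    @serve.batch(max_batch_size=32)
    async def text_to_embeddings(self, prompts: List[str]) -> List[List[float]]:
        # Stand-in for tokenizing and encoding the whole batch in one pass.
        return [[float(len(p))] for p in prompts]

    async def __call__(self, http_request) -> List[float]:
        prompt = (await http_request.json())["text"]
        # Each request contributes ONE element; Serve assembles the batch.
        return await self.text_to_embeddings(prompt)


deployment_graph = BatchedEncoder.bind()

With this contract the HTTP path stays one-request-one-prompt (`serve run <module>:deployment_graph`, then POST {"text": "..."} to port 8000), and batching happens transparently under load.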
experimental/clip_api_app_client.py
CHANGED
@@ -1,10 +1,10 @@
-
-from
-import json
-import os
-import requests
-from concurrent.futures import ThreadPoolExecutor, as_completed
+import ray
+from ray import serve
 import time
+import asyncio
+
+# Create a Semaphore object
+semaphore = asyncio.Semaphore(10)
 
 test_image_url = "https://static.wixstatic.com/media/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg/v1/fill/w_454,h_333,fp_0.50_0.50,q_90/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg"
 english_text = (
@@ -12,38 +12,44 @@ english_text = (
     "of wisdom, it was the age of foolishness, it was the epoch of belief"
 )
 
+async def send_text_request(serve_client, number):
+    async with semaphore:
+        # async_handle = serve_client.get_handle("CLIPTransform", sync=False)
+        async_handle = serve.get_deployment("CLIPTransform").get_handle(sync=False)
+        # async_handle = serve.get_deployment("CLIPTransform").get_handle()
+        embeddings = ray.get(await async_handle.text_to_embeddings.remote(english_text))
+        # embeddings = await async_handle.text_to_embeddings.remote(english_text)
+        # embeddings = async_handle.text_to_embeddings.remote(english_text)
+        # embeddings = await ray.get(embeddings)
+        return number, embeddings
 
-def send_text_request(number):
-    url = "http://127.0.0.1:8000/"
-    data = {"text": english_text}
-    response = requests.post(url, json=data)
-    return number, response.text
-
-
-def process_text(numbers, max_workers=10):
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = [executor.submit(send_text_request, number) for number in numbers]
-        for future in as_completed(futures):
-            n_result, result = future.result()
-            result = json.loads(result)
-            print (f"{n_result} : {len(result[0])}")
-
-# def process_text(numbers, max_workers=10):
-#     for n in numbers:
-#         n_result, result = send_text_request(n)
-#         result = json.loads(result)
-#         print (f"{n_result} : {len(result[0])}")
+# def process_text(server_client, numbers, max_workers=10):
+#     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+#         futures = [executor.submit(send_text_request, server_client, number) for number in numbers]
+#         for future in as_completed(futures):
+#             n_result, result = future.result()
+#             print (f"{n_result} : {len(result[0])}")
+async def process_text(server_client, numbers):
+    tasks = [send_text_request(server_client, number) for number in numbers]
+    for future in asyncio.as_completed(tasks):
+        n_result, result = await future
+        print (f"{n_result} : {len(result[0])}")
 
 if __name__ == "__main__":
     # n_calls = 100000
-    n_calls =
+    n_calls = 1
     numbers = list(range(n_calls))
+    ray.init()
+    server_client = serve.start(detached=True)
     start_time = time.monotonic()
-    process_text(numbers)
+
+    # Run the async function
+    asyncio.run(process_text(server_client, numbers))
+
     end_time = time.monotonic()
     total_time = end_time - start_time
     avg_time_ms = total_time / n_calls * 1000
     calls_per_sec = n_calls / total_time
     print(f"Average time taken: {avg_time_ms:.2f} ms")
     print(f"Number of calls per second: {calls_per_sec:.2f}")
-
+    ray.shutdown()