Sunday, June 16, 2024

One-Shot LLM + RAG with Knowledge Graph

Background.

One-Shot. 

One-shot learning refers to a model's ability to perform a task after seeing only one example. With an LLM, that example is placed in the prompt, and the model uses it as a reference for what the task is and what an appropriate response looks like.

Application in LLM RAG.

  • Single Example Provided: The prompt includes a specific example that demonstrates how the task should be performed.
  • Learning from Example: The model uses this single example to learn how to apply the task's requirements to new inputs.
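
The two points above can be sketched as a small prompt builder. This is only a minimal illustration, not the code used later in this post: the helper name `build_one_shot_prompt` and the example question/answer pair are assumptions made up for the sketch.

```python
def build_one_shot_prompt(example_input: str, example_output: str, new_input: str) -> str:
    """Assemble a prompt that contains exactly one worked example (hypothetical helper)."""
    return (
        "Answer the question in the same style as the example.\n\n"
        f"Example question: {example_input}\n"
        f"Example answer: {example_output}\n\n"
        f"Question: {new_input}\n"
        "Answer:"
    )


# Illustrative values only; this example pair is an assumption, not data from the post.
prompt = build_one_shot_prompt(
    example_input="Where are bears found?",
    example_output="Bears are found on North America, South America, and Eurasia.",
    new_input="Are bears social animals?",
)
print(prompt)
```

The single example sets both the task and the answer format, so the new question can be answered without any further demonstrations.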

Code.

import ollama
import chromadb

# Source sentences (kept for reference; the triplets below are what gets embedded)
documents = [
    "Bears are carnivoran mammals of the family Ursidae.",
    "They are classified as caniforms, or doglike carnivorans.",
    "Although only eight species of bears are extant, they are widespread, appearing in a wide variety of habitats throughout most of the Northern Hemisphere and partially in the Southern Hemisphere.",
    "Bears are found on the continents of North America, South America, and Eurasia.",
    "Common characteristics of modern bears include large bodies with stocky legs, long snouts, small rounded ears, shaggy hair, plantigrade paws with five nonretractile claws, and short tails.",
    "With the exception of courting individuals and mothers with their young, bears are typically solitary animals.",
    "They may be diurnal or nocturnal and have an excellent sense of smell.",
    "Despite their heavy build and awkward gait, they are adept runners, climbers, and swimmers.",
    "Bears use shelters, such as caves and logs, as their dens; most species occupy their dens during the winter for a long period of hibernation, up to 100 days.",
]
# Knowledge-graph triplets in (subject, predicate, object) form
Kg_triplets = [
    ["Bears", "are", "carnivoran mammals"],
    ["Bears", "belong to", "family Ursidae"],
    ["Bears", "are classified as", "caniforms"],
    ["Caniforms", "are", "doglike carnivorans"],
    ["Bears", "have", "eight species"],
    ["Bears", "are", "widespread"],
    ["Bears", "appear in", "a wide variety of habitats"],
    ["Bears", "are found in", "Northern Hemisphere"],
    ["Bears", "are partially found in", "Southern Hemisphere"],
    ["Bears", "are found on", "North America"],
    ["Bears", "are found on", "South America"],
    ["Bears", "are found on", "Eurasia"],
    ["Modern bears", "have", "large bodies"],
    ["Modern bears", "have", "stocky legs"],
    ["Modern bears", "have", "long snouts"],
    ["Modern bears", "have", "small rounded ears"],
    ["Modern bears", "have", "shaggy hair"],
    ["Modern bears", "have", "plantigrade paws with five nonretractile claws"],
    ["Modern bears", "have", "short tails"],
    ["Bears", "are typically", "solitary animals"],
    ["Bears", "can be", "diurnal"],
    ["Bears", "can be", "nocturnal"],
    ["Bears", "have", "an excellent sense of smell"],
    ["Bears", "are", "adept runners"],
    ["Bears", "are", "adept climbers"],
    ["Bears", "are", "adept swimmers"],
    ["Bears", "use", "shelters"],
    ["Shelters", "include", "caves"],
    ["Shelters", "include", "logs"],
    ["Bears", "use", "shelters as dens"],
    ["Most species", "occupy", "dens during winter"],
    ["Bears", "hibernate for", "up to 100 days"],
]
# Convert a (subject, predicate, object) triplet to a plain-text sentence
def triplet_to_text(triplet):
    txt = str(triplet[0]) + " " + str(triplet[1]) + " " + str(triplet[2])
    # print(txt)
    return txt
# triplet_texts = [triplet_to_text(triplet) for triplet in Kg_triplets]
# Create the database
client = chromadb.PersistentClient(path="E:\\Niraj_Work\\DL_Projects\\llm_projects\\database_tmp")
# Use cosine distance for the index; this only takes effect if passed when the collection is created
metadata = {"hnsw:space": "cosine"}
collection = client.create_collection(name="bear_kg_one_shot", metadata=metadata)
# Store each triplet as a vector embedding in the database

for d in range(0, len(Kg_triplets)):
    triplet_txt = triplet_to_text(Kg_triplets[d])
    response = ollama.embeddings(model="mxbai-embed-large", prompt=triplet_txt)
    embedding = response["embedding"]
    collection.add(
        ids=[str(d)],
        embeddings=[embedding],
        documents=[triplet_txt]
    )

# an example prompt
prompt = "What does a polar bear's body look like?"
one_shot_context = "The polar bear is a large bear native to the Arctic and nearby areas."
augmented_text = one_shot_context+"\n"+prompt
# Generate an embedding for the augmented prompt and retrieve the most relevant triplets
response = ollama.embeddings(
    prompt=augmented_text,
    model="mxbai-embed-large"
)
results = collection.query(
    query_embeddings=[response["embedding"]],
    n_results=4
)
print("result = ",results)
# print(collection.get(include=['embeddings','documents','metadatas']))

# results['documents'] holds one list of documents per query embedding.
# We sent a single query, so join all retrieved triplets from the first
# (and only) list into one context string.
supported_docs = results['documents'][0]
data = " ".join(supported_docs).strip()
# Generate a response combining the augmented prompt and the retrieved data
output = ollama.generate(
    model="llama3",
    prompt=f"Using this data: {data}. Respond to this prompt: {augmented_text}"
)

print(output['response'])
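
To make the retrieval step less of a black box, here is a rough, dependency-free sketch of what a cosine-space nearest-neighbour query does conceptually. The toy 3-dimensional vectors and the helper names `cosine_distance` and `top_k` are assumptions for illustration only. Chroma uses an approximate HNSW index rather than this brute-force scan, and real embeddings from the embedding model have far more dimensions.

```python
import math


def cosine_distance(a, b):
    """Return 1 - cosine similarity; a smaller value means the vectors are more alike."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def top_k(query_vec, items, k):
    """Brute-force ranking of (text, vector) pairs by cosine distance to the query."""
    ranked = sorted(items, key=lambda item: cosine_distance(query_vec, item[1]))
    return [text for text, _ in ranked[:k]]


# Toy vectors (made up); real embeddings would come from the embedding model.
toy_index = [
    ("Modern bears have large bodies", [0.9, 0.1, 0.0]),
    ("Bears are adept swimmers", [0.1, 0.9, 0.0]),
    ("Bears hibernate for up to 100 days", [0.0, 0.1, 0.9]),
]
query = [1.0, 0.2, 0.0]
nearest = top_k(query, toy_index, k=2)
print(nearest)
```

The retrieved texts play the same role as `results['documents'][0]` in the listing above: they are joined into the context string that is passed to the generation model.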
