Sunday, June 16, 2024

Zero-Shot LLM-RAG With Knowledge Graph

Background.

Zero-Shot. 

Zero-shot learning refers to the model's ability to perform a task without having seen any explicit examples of that task during training or in the prompt. Instead, the model relies on its pre-existing knowledge and understanding of language to generate the desired output.
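
For instance, a sentiment-classification prompt is zero-shot when it contains only the task description; a few-shot prompt would prepend worked examples. A minimal sketch using Ollama (the model name is an assumption; use any model you have pulled locally):

import ollama

# Zero-shot: the task description alone, with no examples of the task
zero_shot_prompt = "Classify the sentiment of this review as positive or negative: 'The food was cold.'"

# Few-shot, shown only for contrast: the same task preceded by worked examples
few_shot_prompt = (
    "Review: 'Great service!' Sentiment: positive\n"
    "Review: 'Terrible wait times.' Sentiment: negative\n"
    "Review: 'The food was cold.' Sentiment:"
)

output = ollama.generate(model="llama3", prompt=zero_shot_prompt)
print(output["response"])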

Application in LLM RAG.

  • No Examples Provided: When a task is given to the model, it receives only the task description or query, without any specific examples to guide its response.
  • Generalization: The model generalizes from the task description to understand what is required and generates a response based on its training data (a minimal prompt-assembly sketch follows this list).
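
In code, this means the final prompt contains only the retrieved context and the user's question. A minimal sketch, assuming the same template the working code below uses (the helper name is illustrative):

def build_zero_shot_rag_prompt(context: str, question: str) -> str:
    # Zero-shot: only the retrieved context and the task itself, no worked examples
    return f"Using this data: {context}. Respond to this prompt: {question}"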

Video Tutorial.

(Embedded video in the original post.)

Working Code.

import ollama
import chromadb

# Source sentences about bears; the knowledge-graph triplets below restate them
documents = [
"Bears are carnivoran mammals of the family Ursidae.",
"They are classified as caniforms, or doglike carnivorans.",
"Although only eight species of bears are extant, they are widespread, appearing in a wide variety of habitats throughout most of the Northern Hemisphere and partially in the Southern Hemisphere.",
"Bears are found on the continents of North America, South America, and Eurasia.",
"Common characteristics of modern bears include large bodies with stocky legs, long snouts, small rounded ears, shaggy hair, plantigrade paws with five nonretractile claws, and short tails.",
"With the exception of courting individuals and mothers with their young, bears are typically solitary animals.",
"They may be diurnal or nocturnal and have an excellent sense of smell.",
"Despite their heavy build and awkward gait, they are adept runners, climbers, and swimmers.",
"Bears use shelters, such as caves and logs, as their dens; most species occupy their dens during the winter for a long period of hibernation, up to 100 days.",
]
# Knowledge-graph triplets in [subject, predicate, object] form
Kg_triplets = [
["Bears", "are", "carnivoran mammals"],
["Bears", "belong to", "family Ursidae"],
["Bears", "are classified as", "caniforms"],
["Caniforms", "are", "doglike carnivorans"],
["Bears", "have", "eight species"],
["Bears", "are", "widespread"],
["Bears", "appear in", "a wide variety of habitats"],
["Bears", "are found in", "Northern Hemisphere"],
["Bears", "are partially found in", "Southern Hemisphere"],
["Bears", "are found on", "North America"],
["Bears", "are found on", "South America"],
["Bears", "are found on", "Eurasia"],
["Modern bears", "have", "large bodies"],
["Modern bears", "have", "stocky legs"],
["Modern bears", "have", "long snouts"],
["Modern bears", "have", "small rounded ears"],
["Modern bears", "have", "shaggy hair"],
["Modern bears", "have", "plantigrade paws with five nonretractile claws"],
["Modern bears", "have", "short tails"],
["Bears", "are typically", "solitary animals"],
["Bears", "can be", "diurnal"],
["Bears", "can be", "nocturnal"],
["Bears", "have", "an excellent sense of smell"],
["Bears", "are", "adept runners"],
["Bears", "are", "adept climbers"],
["Bears", "are", "adept swimmers"],
["Bears", "use", "shelters"],
["Shelters", "include", "caves"],
["Shelters", "include", "logs"],
["Bears", "use", "shelters as dens"],
["Most species", "occupy", "dens during winter"],
["Bears", "hibernate for", "up to 100 days"],]
# Convert a [subject, predicate, object] triplet into a plain-text sentence
def triplet_to_text(triplet):
    return f"{triplet[0]} {triplet[1]} {triplet[2]}"
# Create a persistent vector database (the path is machine-specific; adjust as needed)
client = chromadb.PersistentClient(path="E:\\Niraj_Work\\DL_Projects\\llm_projects\\database_tmp")
# Use cosine distance for the HNSW index
collection = client.create_collection(name="bear_kg", metadata={"hnsw:space": "cosine"})

# Embed each triplet and store it in the vector database
for d in range(len(Kg_triplets)):
    triplet_txt = triplet_to_text(Kg_triplets[d])
    response = ollama.embeddings(model="mxbai-embed-large", prompt=triplet_txt)
    embedding = response["embedding"]
    collection.add(
        ids=[str(d)],
        embeddings=[embedding],
        documents=[triplet_txt],
    )

# An example prompt
prompt = "What does a bear's body look like?"

# Generate an embedding for the prompt and retrieve the most relevant triplets
response = ollama.embeddings(
    prompt=prompt,
    model="mxbai-embed-large",
)
results = collection.query(
    query_embeddings=[response["embedding"]],
    n_results=3,
)
print("result = ",results)
print(collection.get(include=['embeddings','documents','metadatas']))

# Chroma returns one list of documents per query embedding, so take the single
# query's list of retrieved triplets and join them into one context string
data = " ".join(results["documents"][0]).strip()
# Generate a response combining the retrieved context and the prompt
# (zero-shot: the model sees only the task and the data, no examples)
output = ollama.generate(
    model="llama3",
    prompt=f"Using this data: {data}. Respond to this prompt: {prompt}",
)

print(output["response"])
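
A note on the design: the vector store holds the knowledge-graph triplets rather than the raw sentences in documents. Each triplet is a single atomic fact, so cosine search over the triplet embeddings tends to return tightly scoped context for the prompt, at the cost of having to extract the triplets up front.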
