Run LLMs on Your CPU with Llama.cpp: A Step-by-Step Guide

Large language models (LLMs) are becoming increasingly popular, but they can be computationally expensive to run. Advancements such as 4-bit and 8-bit model loading on HuggingFace reduce the memory footprint, but they still require a GPU, which limits their use to people with access to specialized hardware. While it is possible to run these LLMs on CPUs, the performance has traditionally been too limited to be practical.

Recent work by Georgi Gerganov has changed this: his llama.cpp library provides high-speed inference for a variety of LLMs directly on the CPU.

The original llama.cpp library focuses on running models locally from a shell, which offers little flexibility and makes it hard to leverage the vast range of Python libraries when building applications. Recently, LLM frameworks like LangChain have added support for llama.cpp through the llama-cpp-python package.

In this blog post, we will see how to use the llama.cpp library in Python using the llama-cpp-python package. This package provides Python bindings for llama.cpp, which makes it easy to use the library in Python.

We will also see how to use the llama-cpp-python library to run Zephyr, an open-source LLM based on the Mistral model.

Set up llama-cpp-python

Setting up the Python bindings is as simple as running the following command:

pip install llama-cpp-python

For more detailed installation instructions, please see the llama-cpp-python documentation: https://github.com/abetlen/llama-cpp-python#installation-from-pypi-recommended.

Using an LLM with llama-cpp-python

Once you have installed the llama-cpp-python package, you can start using it to run LLMs.

You can use any language model with llama.cpp provided that it has been converted to the GGUF format (the successor to the older GGML format). GGUF versions of most popular LLMs are already available and can easily be found on HuggingFace.

An important thing to note is that the original LLMs have been quantized when converting them to the GGUF format. This helps reduce the memory required to run these large models without a significant loss in performance. For example, quantization lets us load a 7-billion-parameter model that is 13 GB in its original form in less than 4 GB of RAM.

In this article we use the GGUF version of Zephyr-7B-Beta which is available on the HuggingFace Hub.

The model can be downloaded from here: https://huggingface.co/TheBloke/zephyr-7B-beta-GGUF.

Downloading the GGUF file and Loading the LLM

The following code can be used to download the model. It downloads the required GGUF file, in this case zephyr-7b-beta.Q4_0.gguf, from the HuggingFace Hub, and checks whether the file is already present before attempting to download it.

import os
import urllib.request


def download_file(file_link, filename):
    # Checks if the file already exists before downloading
    if not os.path.isfile(filename):
        urllib.request.urlretrieve(file_link, filename)
        print("File downloaded successfully.")
    else:
        print("File already exists.")

# Downloading the GGUF model from HuggingFace
ggml_model_path = "https://huggingface.co/TheBloke/zephyr-7B-beta-GGUF/resolve/main/zephyr-7b-beta.Q4_0.gguf"
filename = "zephyr-7b-beta.Q4_0.gguf"

download_file(ggml_model_path, filename)

The next step is to load the model that you want to use. This can be done using the following code:

from llama_cpp import Llama

llm = Llama(model_path="zephyr-7b-beta.Q4_0.gguf", n_ctx=512, n_batch=126)

There are two important parameters that should be set when loading the model.

  • n_ctx: This is used to set the maximum context size of the model. The default value is 512 tokens.

The context size is the sum of the number of tokens in the input prompt and the max number of tokens that can be generated by the model. A model with smaller context size generates text much quicker than a model with a larger context size. If the use case does not demand very long generations or prompts, it is better to reduce the context length for better performance.

The number of tokens in the prompt and generated text can be checked using the free Tokenizer tool by OpenAI.

  • n_batch: This is used to set the maximum number of prompt tokens to batch together when generating the text. The default value is 512 tokens.

The n_batch parameter should be set carefully. Lowering the n_batch helps speed up text generation over multithreaded CPUs. Reducing it too much may cause the text generation to deteriorate significantly.

The complete list of parameters can be viewed here: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama
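Beyond n_ctx and n_batch, a few other loading options are commonly adjusted. The following is a minimal sketch; the values shown are illustrative assumptions, not recommendations:

from llama_cpp import Llama

# Illustrative loading options (values are placeholders, not tuned settings)
llm = Llama(
    model_path="zephyr-7b-beta.Q4_0.gguf",
    n_ctx=512,      # maximum context size (prompt + generated tokens)
    n_batch=126,    # number of prompt tokens processed per batch
    n_threads=4,    # number of CPU threads used for inference
    seed=42,        # fix the random seed for reproducible sampling
    verbose=False,  # silence llama.cpp loading and timing logs
)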

Generating Text using the LLM

The following code writes a simple wrapper function to generate text using the LLM.

def generate_text(
    prompt="Who is the CEO of Apple?",
    max_tokens=256,
    temperature=0.1,
    top_p=0.5,
    echo=False,
    stop=["#"],
):
    output = llm(
        prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        echo=echo,
        stop=stop,
    )
    output_text = output["choices"][0]["text"].strip()
    return output_text


def generate_prompt_from_template(input):
    chat_prompt_template = f"""<|im_start|>system
You are a helpful chatbot.<|im_end|>
<|im_start|>user
{input}<|im_end|>"""
    return chat_prompt_template


prompt = generate_prompt_from_template(
    "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."
)

generate_text(
    prompt,
    max_tokens=356,
)

The llm object has several important parameters that are used while generating text:

  • prompt: The input prompt to the model. This text is tokenized and passed to the model.

  • max_tokens: This parameter sets the maximum number of tokens the model can generate, which controls the length of the generated text. The default value is 128 tokens.

  • temperature: The token sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. Default value is 1.

  • top_p: An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.

  • echo: Boolean parameter to control whether the model returns (echoes) the model prompt at the beginning of the generated text.

  • stop: A list of strings that is used to stop text generation. If the model encounters any of the strings, the text generation will be stopped at that token. Used to control model hallucination and prevent the model from generating unnecessary text.
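The same call also supports streaming: passing stream=True returns an iterator that yields partial completions as they are generated, which is useful for chat-style interfaces. A minimal sketch (the prompt and sampling values are illustrative):

# Print tokens as they are generated instead of waiting for the full completion
stream = llm(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    temperature=0.1,
    stop=["Q:"],
    stream=True,
)
for chunk in stream:
    print(chunk["choices"][0]["text"], end="", flush=True)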

The llm object returns a dictionary object of the form:

{
  "id": "xxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",  # text generation id 
  "object": "text_completion",              # object name
  "created": 1679561337,                    # time stamp
  "model": "./models/7B/zephyr-7b-model.gguf",    # model path
  "choices": [
    {
      "text": "Q: Name the planets in the solar system? A: Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, Neptune and Pluto.", # generated text
      "index": 0,
      "logprobs": None,
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 14,       # Number of tokens present in the prompt
    "completion_tokens": 28,   # Number of tokens present in the generated text
    "total_tokens": 42
  }
}

The generated text can be easily extracted from the dictionary object using output["choices"][0]["text"].
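For example, assuming output holds the dictionary returned by the call above, the text and token counts can be pulled out directly:

# Extract the generated text and the token usage from the output dictionary
generated_text = output["choices"][0]["text"].strip()
prompt_tokens = output["usage"]["prompt_tokens"]
completion_tokens = output["usage"]["completion_tokens"]
print(generated_text)
print(f"Prompt tokens: {prompt_tokens}, completion tokens: {completion_tokens}")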

Example text generation using Zephyr-7B

import os
import urllib.request
from llama_cpp import Llama


def download_file(file_link, filename):
    # Checks if the file already exists before downloading
    if not os.path.isfile(filename):
        urllib.request.urlretrieve(file_link, filename)
        print("File downloaded successfully.")
    else:
        print("File already exists.")


# Downloading the GGUF model from HuggingFace
ggml_model_path = "https://huggingface.co/TheBloke/zephyr-7B-beta-GGUF/resolve/main/zephyr-7b-beta.Q4_0.gguf"
filename = "zephyr-7b-beta.Q4_0.gguf"

download_file(ggml_model_path, filename)


llm = Llama(model_path="zephyr-7b-beta.Q4_0.gguf", n_ctx=512, n_batch=126)


def generate_text(
    prompt="Who is the CEO of Apple?",
    max_tokens=256,
    temperature=0.1,
    top_p=0.5,
    echo=False,
    stop=["#"],
):
    output = llm(
        prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        echo=echo,
        stop=stop,
    )
    output_text = output["choices"][0]["text"].strip()
    return output_text


def generate_prompt_from_template(input):
    chat_prompt_template = f"""<|im_start|>system
You are a helpful chatbot.<|im_end|>
<|im_start|>user
{input}<|im_end|>"""
    return chat_prompt_template


prompt = generate_prompt_from_template(
    "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."
)

generate_text(
    prompt,
    max_tokens=356,
)

Generated text:

As the sun began to set over the Pacific Ocean, I found myself standing on the shores of Waikiki Beach in Honolulu, Hawaii. The vibrant colors of the sky painted a breathtaking scene that left me speechless. This was just the beginning of my unforgettable journey through the Aloha State.

Hawaii is a place like no other, where the lush green mountains meet the crystal-clear waters of the ocean. The culture and traditions of Hawaii are deeply rooted in its people, and I was eager to immerse myself in this unique experience.

One of my first stops was the historic Pearl Harbor. As an American, it was a humbling experience to learn about the events that took place here during World War II. The USS Arizona Memorial is a powerful tribute to the men and women who lost their lives during the attack on December 7, 1941.

Next, I headed to the North Shore of Oahu, where I was greeted by the stunning views of Turtle Bay Resort. Here, I had the opportunity to learn about Hawaiian culture through traditional activities such as lei making and ukulele lessons. The locals were incredibly welcoming and eager to share their heritage with me.

One of my favorite experiences in Hawaii was attending a traditional Hawaiian luau. The feast was filled with delicious local cuisine, including poke (raw fish), kalua pig, and poi (a staple food made from taro root). The entertainment included hula dancing, fire knife dancing, and other cultural performances that left me in awe.

The notebook with the example can be viewed here.
The complete code for running the examples can be found on GitHub.

Conclusion

In this blog post, we explored how to use the llama.cpp library in Python with the llama-cpp-python package. These tools enable high-performance CPU-based execution of LLMs. llama.cpp is updated almost daily: inference speed keeps improving and the community regularly adds support for new models. You can also convert your own PyTorch language models into the GGUF format using the "convert.py" script that ships with llama.cpp.

The llama.cpp library and llama-cpp-python package provide robust solutions for running LLMs efficiently on CPUs. If you're interested in incorporating LLMs into your applications, I recommend exploring these resources.

No More Paid Endpoints: How to Create Your Own Free Text Generation Endpoints with Ease

Large language models (LLMs) are gaining popularity because of their capacity to generate text, translate between languages, and produce various forms of creative content. However, one of the biggest challenges of using LLMs is the cost of accessing them. Many LLMs, such as OpenAI's GPT-3, are only available through paid APIs.

Luckily, there is a smart way to use any LLM for free. By deploying your own LLM on an API endpoint, you can access it from anywhere in the world without having to pay any fees. In this article, we will show you how to deploy any open-source LLM as a free API endpoint using HuggingFace and Gradio.

Benefits of Creating Your Own Text Generation Endpoints

  • It can save you money. Paid APIs can be expensive, especially if you are using a large number of requests. By deploying your own LLM, you can avoid these costs.
  • Control over your data. When you use a paid API, you are giving the API provider access to your data. By deploying your own endpoint, you can keep your data safe and secure.
  • Access to the latest models. By deploying your own endpoint, you can choose the LLM you wish to use.
  • Ability to use the LLM capabilities on any device. LLMs require significant resources to run. The API endpoint enables any device connected to the internet to harness the capabilities of the LLM.

Why use Gradio and HuggingFace Spaces?

While there are popular cloud hosting providers like AWS and GCP, their setup process can be complex, and you often need to build your own Flask API. Furthermore, these providers lack free tiers that can handle large language models (LLMs).

Gradio is a tool that makes it easy to create interactive web apps that can be used to interact with LLMs. Huggingface Spaces is a free hosting service that allows you to deploy your machine learning apps to the web.

With the help of a Gradio app's API functionality, we can easily access the Language Model (LLM). We deploy the Gradio app using the free tier of HuggingFace Spaces.

Before we can get started on how to deploy the LLMs, let's create a new space on HuggingFace.

Creating a new Space on HuggingFace

A "Space" on HuggingFace is a hosting environment that can be used to host your ML app. Spaces are priced based on CPU type, and the simplest one is free!

Create a new Space by:

  • Go to https://huggingface.co/spaces and click Create new Space.
    (You will need to sign-up for a HuggingFace Account to create the space.)
  • Select the MIT license if you’re unsure.
  • Select Gradio as Space SDK.
  • Select Public since you want the API endpoint to be available at all times.
Creating a new Space on HuggingFace | Image by Author

Creating the Gradio app to access the LLM

In this article, we create two Gradio apps to access two types of LLM formats:

  • A LLM checkpoint available on HuggingFace (the usual PyTorch model)
  • A CPU-optimized version of the LLM (a GGML-format model run with llama.cpp)

The basic format of the app is the same for both formats:

  1. Load the model.
  2. Create a function that accepts an input prompt and uses the model to return the generated text.
  3. Make a Gradio interface to display the generated text and accept user input.

LLM from a HuggingFace Checkpoint:

In this example we deploy the newly launched Falcon model using its HuggingFace checkpoint.

To create the Gradio app, make a new file called app.py, and add the following code.

app.py

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b-instruct",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b-instruct")


def generate_text(input_text):
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    attention_mask = torch.ones(input_ids.shape)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=200,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(output_text)

    # Remove Prompt Echo from Generated Text
    cleaned_output_text = output_text.replace(input_text, "")
    return cleaned_output_text


text_generation_interface = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(label="Input Text"),
    outputs=gr.Textbox(label="Generated Text"),
    title="Falcon-7B Instruct",
).launch()

This Python script uses the HuggingFace Transformers library to load the tiiuae/falcon-7b-instruct model. The maximum generation length is set to 200 tokens and top-k sampling is set to 10; these generation parameters can be adjusted as per your requirements. The prompt is stripped from the output so that the function returns only the newly generated text rather than the prompt followed by the generated text.

A requirements.txt file is created to specify the dependencies for the app. The following libraries are included in the file:

requirements.txt

datasets
transformers
accelerate
einops
safetensors

The complete example can be viewed at: https://huggingface.co/spaces/awinml/falcon-7b-instruct-api.

The code for the app can be downloaded from: https://huggingface.co/spaces/awinml/falcon-7b-instruct-api/tree/main.

LLM from a CPU-Optimized (GGML) format:

llama.cpp is a C++ library that provides a high-performance inference engine for large language models (LLMs). It is built on GGML, a C tensor library for machine learning that supports quantized model formats. llama.cpp uses GGML to load and run LLMs efficiently, making fast CPU inference on large models possible.

In this example we load the Vicuna model in GGML format and deploy it for inference. The inference time is significantly lower as compared to the model checkpoint available on HuggingFace.

To create the Gradio app, make a new file called app.py, and add the following code.

app.py

import os
import urllib.request
import gradio as gr
from llama_cpp import Llama


def download_file(file_link, filename):
    # Checks if the file already exists before downloading
    if not os.path.isfile(filename):
        urllib.request.urlretrieve(file_link, filename)
        print("File downloaded successfully.")
    else:
        print("File already exists.")


# Downloading the GGML model from HuggingFace
ggml_model_path = "https://huggingface.co/CRD716/ggml-vicuna-1.1-quantized/resolve/main/ggml-vicuna-7b-1.1-q4_1.bin"
filename = "ggml-vicuna-7b-1.1-q4_1.bin"

download_file(ggml_model_path, filename)


llm = Llama(model_path=filename, n_ctx=512, n_batch=126)


def generate_text(prompt="Who is the CEO of Apple?"):
    output = llm(
        prompt,
        max_tokens=256,
        temperature=0.1,
        top_p=0.5,
        echo=False,
        stop=["#"],
    )
    output_text = output["choices"][0]["text"].strip()

    # Remove Prompt Echo from Generated Text
    cleaned_output_text = output_text.replace(prompt, "")
    return cleaned_output_text


description = "Vicuna-7B"

# Example prompts shown in the UI (one entry per input component)
examples = [
    ["What is the capital of France?"],
    ["Who wrote the novel 'Pride and Prejudice'?"],
    ["What is the square root of 64?"],
]

gradio_interface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    examples=examples,
    title="Vicuna-7B",
)
gradio_interface.launch()

The app first downloads the required GGML file, in this case the Vicuna-7b-Q4.1 GGML. The code checks if the file is already present before attempting to download it.

We leverage the Python bindings for llama.cpp (llama-cpp-python) to load the model.

The context length of the model is set to 512 tokens. The maximum supported context length for the Vicuna model is 2048 tokens. A model with a smaller context length generates text much faster than a model with a larger context length. In most cases, a smaller context length is sufficient.

The number of tokens in the prompt and generated text can be checked using the free Tokenizer tool by OpenAI.

The batch size (n_batch) is set to 126 tokens, which helps speed up text generation on multithreaded CPUs.

The max generation length is set to 256 tokens, temperature to 0.1, and top-p sampling of tokens to 0.5. A list of tokens to stop generation is also added. These text generation parameters can be set as per your requirement.

A detailed guide on how to use GGML versions of popular open-source LLMs for fast inference can be found at How to Run LLMs on Your CPU with Llama.cpp: A Step-by-Step Guide.

A requirements.txt file is created to specify the dependencies for the app. The following libraries are included in the file:

requirements.txt

llama-cpp-python==0.1.62

The complete example can be viewed at: https://huggingface.co/spaces/awinml/vicuna-7b-ggml-api.

The code for the app can be downloaded from: https://huggingface.co/spaces/awinml/vicuna-7b-ggml-api/tree/main.

Deploying the Gradio app on HuggingFace Spaces:

Deploying a Gradio app on HuggingFace Spaces is as simple as uploading the following files on your HuggingFace Space:

  • app.py - This file contains the code of the app.
  • requirements.txt - This file lists the dependencies for the app.
Upload the files for the Gradio app on HuggingFace Spaces | Image by Author

The deployed app will expect you to pass in the input text or prompt, which it’ll then use to generate an appropriate response.

Gradio app with the Vicuna model | Image by Author

Accessing the LLM as an API Endpoint:

The deployed Gradio app is already running a Prediction (Inference) API endpoint in the background.
The endpoint can be easily accessed through the Gradio Python Client.

At the bottom of the deployed app, you will see a link called "Use via API". Click this link to view the instructions on how to call your app with the API.

Using the API Endpoints to access the LLM | Image by Author

To use the API, you will need to install the Gradio Client Python library. You can do this by running the following command in your terminal:

pip install gradio_client

Once you have installed the library, you can use any of the deployed apps for generating text similar to the OpenAI completion endpoints in the following manner:

from gradio_client import Client

# Pass the link to your HuggingFace Space here
client = Client("https://awinml-falcon-7b-instruct-api.hf.space/")

# Pass the Input Prompt to the model
result = client.predict(
    "What is the capital of USA?",
    api_name="/predict"
)
print(result)

This code first creates a Client object pointed at your HuggingFace Space, then sends the input prompt to the model through the predict() method. predict() returns the generated text, which is printed to the console.
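To use the endpoint as a drop-in replacement for a paid completion API, the call can be wrapped in a small helper. This is a minimal sketch; the Space URL is the one from the example above and the helper name is hypothetical:

from gradio_client import Client

client = Client("https://awinml-falcon-7b-instruct-api.hf.space/")

def complete(prompt: str) -> str:
    # The Gradio app exposes a single prediction endpoint named "/predict"
    return client.predict(prompt, api_name="/predict")

print(complete("Write a one-line tagline for a coffee shop."))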

Latency Analysis:

We measure the time taken by the endpoints to generate text based on different prompts. The testing was done on two private endpoints created based on the instructions in the post.

The Falcon-7B model was loaded with torch.bfloat16 weights using HuggingFace Transformers. The Vicuna-7B model was loaded from the Q4.1 GGML file using llama.cpp, with a context length of 512 tokens instead of the full 2048 tokens.

For both models, the maximum number of generated tokens was set to 256. The temperature, top_p and top_k values were chosen as described above and were not the same for both models; this post does not aim to compare the text generation quality of the two models.

The following prompts were used to test the latencies of the endpoints:

  • What is the capital of Greece?
  • Can you help me write a formal email to a potential business partner proposing a joint venture?
  • What are the two main differences between Python and JavaScript programming languages?
  • Extract five important keywords from the following text: How can governments utilize fiscal and monetary policies to combat economic recessions?
  • Make a list of the top three Morgan Freeman movies. Only specify the movie names.

Vicuna-7B:

  • Question: What is the capital of Greece?
    Generated Text: The capital of Greece is Athens.
    Time taken: 6.1s.

  • Question: Can you help me write a formal email to a potential business partner proposing a joint venture?
    Generated Text:
    I am writing to express my interest in forming a joint venture with your company. I believe that our complementary strengths and shared vision for the future make this an ideal opportunity for both of us. Our company has extensive experience in [insert relevant industry or field], and we have developed a unique product line that is well-positioned to meet the needs of our customers. We are confident that your company's expertise in [insert relevant area] would be a valuable asset to our partnership, and we believe that together we could create something truly special. I am particularly excited about the potential for collaboration on [insert specific project or initiative], which I believe has the potential to revolutionize the industry. Our teams have already begun discussing the details of this project, and we are confident that it will be a huge success. In addition to our shared goals and complementary strengths, I am also impressed by your company's commitment to [insert relevant social or environmental issue]. As someone who shares these values, I believe that our partnership could make a real difference in the world.

    Time Taken: 126.2s

  • Question: What are the two main differences between Python and JavaScript programming languages?
    Generated Text:
    Python is a general-purpose, high-level language that is easy to learn and use. It has a syntax that is similar to English and supports multiple data types such as integers, strings, and lists. JavaScript, on the other hand, is a client-side scripting language used for web development. It is primarily used to add interactive elements to websites.
    Python is an interpreted language while JavaScript is compiled. Python code is executed by an interpreter, which converts the code into machine code that can be run directly by the computer's processor. In contrast, JavaScript code is first compiled into bytecode and then run by a JavaScript engine.

    Time Taken: 72.4s

  • Question: Make a list of the top three Morgan Freeman movies. Only specify the movie names.
    Generated Text:
    1. The Shawshank Redemption (1994)
    2. Driving Miss Daisy (1989)
    3. Glory (1989)
    Time Taken: 16.1s

Falcon-7B:

  • Question: What is the capital of Greece?
    Generated Text: Athens is the capital of Greece.
    Time taken: 61.2s.

  • Question: Can you help me write a formal email to a potential business partner proposing a joint venture?
    Generated Text:
    --------- Did not capture generation. Took too long! ----------

    Time Taken: 2650s+

  • Question: What are the two main differences between Python and JavaScript programming languages?
    Generated Text:
    1. Syntactical differences: Python is a higher-level language and has a more structured syntax compared to JavaScript, which is a lower-level language with less strict syntax rules.
    2. Use of brackets for function calls: While Python uses the 'def' keyword to define functions, JavaScript functions are defined using brackets or named function literals.
    Time Taken: 1524.3s

  • Question: Make a list of the top three Morgan Freeman movies. Only specify the movie names.
    Generated Text:
    1. The Shawshank Redemption
    2. Forrest Gump
    3. Invincible
    Time Taken: 371.3s

Verdict:

The Falcon-7B model is extremely slow to generate text, with the lowest latency being 61 seconds and the highest being over 2,650 seconds. This is clearly not very useful.

I suspect that this is because HuggingFace only recently added support for Falcon to Transformers and has not yet optimized the model implementation; the architecture is also relatively new and less widely optimized. The HuggingFace implementation additionally relies on external (remote) code to run the model, which may be a bottleneck. We may have better luck with other, more mature generative models on HuggingFace.

The Vicuna GGML model, on the other hand, seems to perform extremely well, with latencies ranging from a mere 6 seconds to 126 seconds for the longest generation.

Llama.cpp is being constantly improved, and using a smaller, quantized version may be able to reduce this latency even further. The LLM loading parameters also significantly affect the performance, so optimizing those may also lead to some speedup.

The good thing is that HuggingFace provides no restrictions on the number of spaces that a user can create. This means that multiple spaces can be created and easily used to process requests in parallel.

Based on this, it is quite easy to recommend creating endpoints using the Vicuna GGML model and using it for prototyping applications instead of the expensive OpenAI GPT-3 API.
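Since each Space runs its own endpoint, several Spaces can also be queried concurrently from a single client to process requests in parallel. The following is a minimal sketch; the Space URLs are placeholders for your own deployments:

from concurrent.futures import ThreadPoolExecutor
from gradio_client import Client

# Placeholder URLs: replace with your own deployed Spaces
space_urls = [
    "https://your-username-llm-api-1.hf.space/",
    "https://your-username-llm-api-2.hf.space/",
]
clients = [Client(url) for url in space_urls]
prompts = ["What is the capital of Greece?", "Define fiscal policy in one sentence."]

def run(client, prompt):
    return client.predict(prompt, api_name="/predict")

# Send one prompt to each Space in parallel
with ThreadPoolExecutor(max_workers=len(clients)) as pool:
    results = list(pool.map(run, clients, prompts))

for prompt, result in zip(prompts, results):
    print(prompt, "->", result)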

Conclusion

Now, you can deploy any large language model (LLM) as an API endpoint with just a few lines of code, thanks to Gradio and HuggingFace Spaces. These tools make it simple to build your own free text generation endpoints. By deploying your own LLM on an API endpoint, you can save money by avoiding costly paid APIs while still benefiting from the remarkable capabilities of these powerful language models.

Financial Dashboard for Market Intelligence
  • Financial Dashboard for Market Intelligence

    Built an end-to-end financial dashboard that collects and consolidates a business's critical observations in one place, using information obtained from the annual 10-K SEC filings of 12 companies.

  • Collected text data from 10-K filings from SEC EDGAR using the SEC ExtractorAPI.

  • The filings of 12 companies spanning 5 sectors were collected for the duration of 2017 to 2021. Each filing had over 34,000 words.

  • The data was cleaned and transformed for Sentiment Analysis and Summarization. The data was manually labelled for both the tasks.

  • The RoBERTa, FinBERT and DistilBERT models were fine-tuned for sentiment analysis. The best results were obtained using the fine-tuned DistilBERT model. It achieved an Accuracy of 91.11% and an ROC-AUC Score of 0.972.

  • The T5, DistilPEGASUS and DistilBART models were fine-tuned for summarization. The best results were obtained using the fine-tuned DistilBART model. It achieved an ROUGE-L Score of 67.7%.

  • RAKE NLTK was used to identify important keywords from the generated summaries.

  • The Financial Dashboard was deployed as a web-app using Streamlit. It contains:

    • Insights and summaries for different sections from annual corporate filings.
    • Identification of important keywords mentioned in the report.
    • Sentiment-based score that measures the company's performance over a certain time period.

The app can be viewed here: Financial Dashboard

Motivation

In the current data driven world, it is essential to have access to the right information for impactful decision making. All publicly listed companies have to file annual reports to the government. These consolidated statements allow investors, financial analysts, business owners and other interested parties to get a complete overview of the company. Companies all over the world make key financial decisions based on annually released public filings.

These corporate filings are rife with complicated legal and financial jargon, making them practically impossible for a layperson to understand. In most cases, these documents have to be read and decoded manually by people with expert financial and legal understanding. The goal of this project is to develop a tool that automates this tedious procedure and makes it easier to acquire crucial financial information.

Data

To extract the text from the SEC filing, the SEC’s ExtractorAPI was used. The API can extract any text section from 10-Q, 10-K, and 8-K SEC filings, and returns the extracted content in cleaned and standardized text or HTML format.
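As a rough sketch of how a single section can be pulled with the sec-api Python package (the API key and filing URL below are placeholders):

from sec_api import ExtractorApi

# Placeholder API key and filing URL, for illustration only
extractor_api = ExtractorApi("YOUR_SEC_API_KEY")
filing_url = "https://www.sec.gov/Archives/edgar/data/0000000000/example-10k.htm"

# Extract Item 7 (Management's Discussion and Analysis) as cleaned text
item_7_text = extractor_api.get_section(filing_url, "7", "text")
print(item_7_text[:500])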
The twelve companies for which the data was collected are listed below, organized by sector:

  1. Pharmaceutical:
    Abbvie, Pfizer, Merck
  2. Technology:
    Alphabet, Meta, Microsoft
  3. Retail:
    Costco
  4. Oil and Natural Gas:
    Chevron
  5. Food and Beverages:
    Coca Cola, Pepsico

Snapshot of the data:

Snapshot of the data | Image by Author

Sentiment Analysis

A local cross-validation split was created by randomly sampling rows from the records of the 12 companies across sectors like Technology, Finance, Retail and Pharma. A sample 10-K report for Meta can be viewed here.

The RoBERTa, FinBERT and DistilBERT models were fine-tuned for sentiment analysis. The best results were obtained using the fine-tuned DistilBERT model. It achieved an Accuracy of 91.11% and an ROC-AUC Score of 0.972.

Model | Accuracy | F1 | AUC
RoBERTa | 0.662 | 0.656 | 0.628
FinBERT | 0.746 | 0.682 | 0.721
DistilBERT | 0.911 | 0.914 | 0.972

Summarization

For the summarization task, the data of Pfizer, Costco and Meta was labeled and used. A local cross-validation split was created by randomly sampling rows from the records of these companies.

Text summarization was carried out using three transformer models: T5, DistilPEGASUS and DistilBART, each of which was fine-tuned on the labeled data. The best results were obtained using the fine-tuned DistilBART model, which achieved a ROUGE-L score of 67.7%.

Model | ROUGE-1 | ROUGE-2 | ROUGE-L | ROUGE-LSUM
T5 | 32.22 | 28.5 | 31.5 | 31.5
DistilPEGASUS | 48.32 | 34.48 | 43.51 | 31.50
DistilBART | 72.28 | 61.15 | 67.70 | 71

Identifying Important Keywords

RAKE NLTK was used to identify important keywords from the generated summaries.
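A minimal sketch of keyword extraction with the rake-nltk package (the summary text is a placeholder):

from rake_nltk import Rake

# Placeholder summary text for illustration
summary = (
    "The company reported strong revenue growth driven by cloud services, "
    "while operating expenses increased due to research and development."
)

# May require nltk.download("stopwords") and nltk.download("punkt") on first use
rake = Rake()  # uses the NLTK English stopword list by default
rake.extract_keywords_from_text(summary)
keywords = rake.get_ranked_phrases()[:5]  # top 5 ranked keyword phrases
print(keywords)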

Code:

The code to run the project can be found here: Financial Dashboard for Market Intelligence - Github.

American Express - Default Prediction
  • American Express - Default Prediction

    Built a classification model to predict the probability that a customer does not pay back their credit card balance (defaults) based on their monthly customer statements using the data provided by American Express.

  • The data was particularly challenging to deal with as it had 5.5 million records and 191 anonymized features. 122 features had more than 10% missing values. The target variable had severe class imbalance.

  • Engineered new features by taking different aggregations over time which helped increase model accuracy by 12%.

  • Optimized XGBoost and LightGBM classifiers using RandomizedSearchCV to reach the best model.

  • A Soft-Voting Ensemble of the best performing XGBoost and LightGBM Models was used to make final predictions which yielded an Accuracy of 94.48%, an F1-Score of 96.71% and an ROC-AUC Score of 96.40%.

Data

Credit default prediction is central to managing risk in a consumer lending business. Credit default prediction allows lenders to optimize lending decisions, which leads to a better customer experience and sound business economics.

The dataset contains profile features for each customer at each statement date. Features are anonymized and normalized, and fall into the following general categories:

D_* = Delinquency variables
S_* = Spend variables
P_* = Payment variables
B_* = Balance variables
R_* = Risk variables

with the following features being categorical:

'B_30', 'B_38', 'D_114', 'D_116', 'D_117', 'D_120', 'D_126', 'D_63', 'D_64', 'D_66', 'D_68'

The dataset can be downloaded from here.

Analysis

The complete analysis can be viewed here.

Target Distribution

  • In the data present we observe that 25.9% of records have defaulted on their credit card payments whereas 74.1% have paid their bills on time.
  • This distribution shows us that there is severe class imbalance present.

Distribution of Number of Defaults per day for the first Month:

The proportion of customers that default is consistent across each day in the data, with a slight weekly seasonal trend influenced by the day on which customers receive their statements.


Frequency of Customer Statements for the first month:

  • There is weekly seasonal pattern observed in the number of statements received per day.
  • As seen above this trend does not seem to be significantly affecting the proportion of default.

Distribution of values of Payment Variables:

  • We notice that Payment 2 is heavily negatively skewed (left skewed).
  • Even though Payment 4 has continuous values between 0 and 1, most of its density is clustered around 0 and 1.
  • This suggests that some Gaussian noise may be present; the noise can be removed and the variable converted into a binary variable.

Correlation of Features with Target Variable:

  • Payment 2 is negatively correlated with the target with a correlation of -0.67.
  • Delinquency 48 is positively correlated with the target with a correlation of 0.61.

Correlation of Payment Variables with Target

  • We observe that Payment 2 and the Target are highly negatively correlated.
  • This could probably be due to the fact that people paying their bill have a lower chance of default.

Experiments:

  • The dataset presents a significant challenge with a substantial number of missing values, making imputation impossible due to the anonymization of features and the lack of a clear rationale behind imputation. This unique constraint compels us to select models that are capable of handling missing values, pushing us to explore advanced techniques.

  • One notable characteristic of the dataset is its high cardinality, boasting an impressive 191 features. However, the presence of missing values poses limitations on the utilization of conventional dimensionality reduction techniques like Principal Component Analysis (PCA) and feature selection methods such as Recursive Feature Elimination (RFE). Thus, we must seek alternative approaches to tackle this challenge.

  • To overcome the limitations imposed by missing values, we employ a different strategy: engineering new features through aggregations over the time dimension (a pandas sketch of this step appears at the end of this section). By ignoring missing values during aggregation, we generate dense, informative engineered features that can be used effectively for modeling.

  • Several prominent classification models that gracefully handle missing values have been considered for this task, including XGBoost, LightGBM, and CatBoost. Notably, these models internally incorporate imputation techniques, dynamically adapting them based on the approach that yields the greatest performance improvement. This ensures that missing values do not impede the model's ability to make accurate predictions.

  • To establish a performance baseline, an XGBoost model with default hyperparameters was trained, yielding an impressive accuracy of 78.84%, an F1-Score of 54.64%, and an ROC-AUC Score of 65.72%. These results serve as a solid starting point for further improvement.

  • Subsequently, the LightGBM model with default hyperparameters was employed, resulting in notable enhancements across performance metrics. Specifically, the LightGBM model boosted the accuracy by 1%, the F1-Score by 12%, and the ROC-AUC Score by 6%, further solidifying its effectiveness.

  • To fine-tune the XGBoost and LightGBM models, a Randomized Grid Search was conducted utilizing 5-fold cross-validation. This comprehensive search approach allowed us to explore a wide range of hyperparameter combinations efficiently.

  • Hyperparameters of the XGBoost model, such as n_estimators, max_depth, and learning_rate, were meticulously tuned, resulting in remarkable improvements. The accuracy was enhanced by 9%, the F1-Score by 18%, and the ROC-AUC Score by 3%, showcasing the model's capacity to capture more nuanced patterns in the data.

  • Similarly, the hyperparameters of the LightGBM model, including n_estimators, feature_fraction, and learning_rate, were fine-tuned through the Randomized Grid Search. This meticulous optimization process led to a marginal but meaningful accuracy improvement of 0.1%, an F1-Score boost of 6%, and an impressive 10% enhancement in the ROC-AUC Score.

By utilizing advanced techniques, engineering informative features, and fine-tuning the models through extensive hyperparameter optimization, we were able to elevate the performance of both the XGBoost and LightGBM models while gracefully handling the challenges imposed by missing values and high feature cardinality.
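A minimal pandas sketch of the time-aggregation idea, assuming a dataframe with one row per customer statement, a customer_ID column and numeric statement features (the file name and feature names are placeholders):

import pandas as pd

# Placeholder file and feature names; the real data has 191 anonymized features
df = pd.read_parquet("train_data.parquet")
numeric_features = ["P_2", "B_1"]

# Aggregate each feature over a customer's statement history.
# Pandas aggregations skip NaN values by default, so missing entries
# do not need to be imputed beforehand.
agg_features = df.groupby("customer_ID")[numeric_features].agg(
    ["mean", "std", "min", "max", "last"]
)
agg_features.columns = ["_".join(col) for col in agg_features.columns]
print(agg_features.head())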

Results:

A Soft-Voting Classifier was used to create an ensemble of the two models and generate the final predictions. It achieved an Accuracy of 94.48%, an F1-Score of 96.71% and an ROC-AUC Score of 96.40%.
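A minimal scikit-learn sketch of a soft-voting ensemble over the two models (the data and hyperparameter values below are placeholders, not the tuned ones):

from sklearn.datasets import make_classification
from sklearn.ensemble import VotingClassifier
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier

# Synthetic, imbalanced stand-in for the engineered feature matrix
X, y = make_classification(n_samples=2000, n_features=20, weights=[0.74], random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

# Placeholder hyperparameters; the actual values came from the randomized search
xgb = XGBClassifier(n_estimators=300, max_depth=6, learning_rate=0.05)
lgbm = LGBMClassifier(n_estimators=300, num_leaves=64, learning_rate=0.05)

# Soft voting averages the predicted class probabilities of both models
ensemble = VotingClassifier(estimators=[("xgb", xgb), ("lgbm", lgbm)], voting="soft")
ensemble.fit(X_train, y_train)
default_probability = ensemble.predict_proba(X_valid)[:, 1]
print(default_probability[:5])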

The results from all the models have been summarized below:

Model | Accuracy | F1-Score | ROC-AUC Score
XGBoost (default) | 78.84 | 54.64 | 65.72
LightGBM (default) | 79.84 | 62.92 | 71.86
XGBoost (fine-tuned) | 88.61 | 80.74 | 74.96
LightGBM (fine-tuned) | 88.72 | 86.42 | 84.22
Voting Classifier (XGB + LGBM) | 94.48 | 96.72 | 96.40

Run Locally

The code to run the project can be found here: American Express - Default Prediction Github.

  1. Install required libraries:
  pip install -r requirements.txt
  2. Generate features:
  python amex-feature-engg.py
  3. Fine-tune models:
  python amex-fine-tuning.py
  4. Generate predictions:
  python amex-final-prediction.py

Feedback

If you have any feedback, please reach out to me.

H&M Personalized Product Recommendations
  • H&M Personalized Product Recommendations

    Built a product recommendation system to recommend products based on previous transactions, as well as from customer and product meta data using the data provided by H&M.

  • The data contains 105,542 unique products with information on 24 characteristics for each product.

  • The data contains information on 1,371,980 customers and 31,788,324 customer transactions from 2018 to 2020.

  • A custom lightweight candidate retrieval method was created by combining candidates that were purchased together in the last week with the most popular candidates for each age group.

  • The candidates were ranked using a LightGBM model based on features created using the frequency of product purchase as well as the percentage of customers that purchased that product.

  • A fine-tuned recommendation system using the custom candidate retrieval method and a LightGBM ranking model was used to make the final predictions, yielding a MAP@12 score of 0.345 and an overall AUC of 0.76.

Data

The purchase history of customers across time, along with supporting metadata has been provided. The goal is to predict what articles each customer will purchase in the 7-day period immediately after the training data ends.

Files provided:

  • articles.csv - detailed metadata for each article_id available for purchase
  • customers.csv - metadata for each customer_id in dataset
  • transactions.csv - the purchases made by each customer on each date, along with additional information. Duplicate rows correspond to multiple purchases of the same item.

The dataset can be downloaded from Kaggle.

Analysis

The complete analysis can be viewed in this notebook.

Distribution of number of Transactions per day:

  • October 2019 recorded the highest number of transactions in duration of 2018 to 2020.
  • There is a quarterly seasonal spike of transactions.
  • There tends to be a large number of transactions in the month of December every year.
Distribution of number of Transactions per day | Image by Author

Distribution of number of Transactions per day grouped by Sales Channel:

  • Sales Channel 1 has a consistent number of transactions per day, with rarely any large spikes.
  • Sales Channel 2 consistently outperforms Sales Channel 1 throughout 2018 to 2020.
  • The quarterly seasonal spike of transactions is caused by transactions through Sales Channel 2.
Distribution of number of Transactions per day grouped by Sales Channel | Image by Author

Distribution of number of unique Articles sold per day grouped by Sales Channel:

  • Sales Channel 1 has a consistent number of unique articles sold per day, with rarely any spikes.
  • Sales Channel 2 consistently sells more unique products per day than Sales Channel 1 throughout 2018 to 2020.
Distribution of number of unique Articles sold per day grouped by Sales Channel | Image by Author

After seeing the distribution of transactions and unique articles sold per day, we get the intuition that Sales Channel 1 customers are more consistent and conservative buyers.

On the other hand, customers that use Sales Channel 2 are ready to try out new products and also purchase products only in specific months during the year.

Distribution of Customers across Age Group:

  • The highest number of customers are aged 21 years.
  • A large proportion of the customer demographic are young adults aged 19 to 26.
  • There is also a significant customer base that is aged 46 to 56 years.
Distribution of Customers across Age Group | Image by Author

Distribution of Customers who have subscribed for Fashion News Alerts:

  • 67% of all the customers have not subscribed for Fashion News Alerts.
  • 32% of the customers have subscribed for regular updates.
  • 1% of the customers have subscribed for monthly updates
Distribution of Customers who have subscribed for Fashion News Alerts | Image by Author

Product Groups with the highest number of Product Types:

The product group 'Accessories' has the highest number of product types followed by 'Shoes' and 'Upper Body Garments'.

Product Groups with the highest number of Product Types | Image by Author

Product Types with the highest number of unique articles:

The product type 'Trousers' has the highest number of unique articles closely followed by 'Dress'.

Product Types with the highest number of unique articles | Image by Author

Product Departments with the highest number of unique articles:

The product type 'Jersey' has the highest number of unique articles closely followed by 'Knitwear'.

Product Departments with the highest number of unique articles | Image by Author

Product Graphical Appearance Names with the highest number of unique articles:

The highest number of articles are of 'Solid' appearance followed by 'All over pattern'.

Product Graphical Appearance Names with the highest number of unique articles | Image by Author

Product Index with the highest number of unique articles:

The index named 'Ladieswear' has the highest number of unique articles closely followed by 'Divided'.

Product Index with the highest number of unique articles | Image by Author

Product Colour Group Names with the highest number of unique articles:

The highest number of articles are of 'Black' colour group followed by 'Dark Blue' and 'White'.

Product Colour Group Names with the highest number of unique articles | Image by Author

Experiments:

The custom retrieval strategy was developed based on careful consideration and intuition, which is explored in detail in this notebook.

Candidate Retrieval

A sophisticated candidate retrieval method was devised by combining multiple effective retrieval strategies. The following approaches were employed:

Recommend Items Purchased Together in the Last Week:

  • Extracted 5 pairs for each article that was sold in the past week.
  • Excluded any article that was not sold within the past week.
  • Filtered out pairs purchased by fewer than 2 customers.

Recommend Items Purchased Together in the Last Few Weeks:

  • Adjusted the number of previous weeks based on empirical analysis and optimization.

Recommend Most Popular Items Based on Age Group:

  • Identified the most popular items considering the specific age group of the customers.

This candidate retrieval method is manually designed and time-aware, incorporating valuable trend information for enhanced performance.
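A rough pandas sketch of the "purchased together in the last week" retrieval step (column names follow the transactions.csv schema; the pair-counting logic is simplified for illustration):

import pandas as pd

transactions = pd.read_csv("transactions.csv", parse_dates=["t_dat"])

# Keep only transactions from the last week of the training period
last_week_start = transactions["t_dat"].max() - pd.Timedelta(days=7)
recent = transactions[transactions["t_dat"] >= last_week_start]

# Build article pairs bought by the same customer within that week
pairs = recent.merge(recent, on="customer_id", suffixes=("", "_paired"))
pairs = pairs[pairs["article_id"] != pairs["article_id_paired"]]

# Count how many distinct customers bought each pair and drop rare pairs
pair_counts = (
    pairs.groupby(["article_id", "article_id_paired"])["customer_id"]
    .nunique()
    .reset_index(name="num_customers")
)
pair_counts = pair_counts[pair_counts["num_customers"] >= 2]

# Keep the top 5 paired articles for every article as candidates
candidates = (
    pair_counts.sort_values("num_customers", ascending=False)
    .groupby("article_id")
    .head(5)
)
print(candidates.head())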

Candidate Ranking:

To further refine the recommendation process, a comprehensive feature creation process was undertaken for candidate ranking. The features utilized are as follows:

Percentage of Customers the Pair was Based on: Calculated the proportion of customers who purchased the particular pair of items.

Recency of Article Purchase: Considered how recently the article was bought by customers.

Number of Times the Pair of Products was Purchased: Accounted for the frequency of purchases made for the given pair.

To rank the candidates, a powerful LightGBM ranking model was employed. The LightGBM model's hyperparameters, such as n_estimators and num_leaves, were meticulously tuned to ensure optimal performance.
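A minimal sketch of ranking candidates with LightGBM's ranker (the feature matrix is synthetic and the hyperparameter values are placeholders; the group argument gives the number of candidate rows per customer):

import numpy as np
from lightgbm import LGBMRanker

# Synthetic stand-in: one row per (customer, candidate article) pair with
# features such as pair popularity, recency and purchase counts
X_train = np.random.rand(1000, 3)
y_train = np.random.randint(0, 2, size=1000)  # 1 if the candidate was actually bought
group_sizes = [10] * 100                      # 100 customers x 10 candidates each

ranker = LGBMRanker(
    objective="lambdarank",
    n_estimators=200,  # placeholder; the real value was tuned
    num_leaves=63,     # placeholder; the real value was tuned
)
ranker.fit(X_train, y_train, group=group_sizes)

# Higher scores rank first; the top 12 per customer are used for MAP@12
scores = ranker.predict(X_train[:10])
print(scores)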

Results:

The recommendation system, tuned with the custom candidate retrieval method and the LightGBM ranking model, achieved a MAP@12 score of 0.345, showing that it can recommend relevant items accurately. The overall AUC (Area Under the Curve) reached 0.76, confirming that the system captures user preferences well and generates high-quality recommendations.

Run Locally

The code to run the project can be found here: H&M Personalized Product Recommendations Github.

  1. Install required libraries:
  pip install -r requirements.txt
  2. Generate local cv:
  python hm-cv.py
  3. Fine-tune models and generate predictions:
  python hm-custom-retrieval-pred.py

Feedback

If you have any feedback, please reach out to me.

Jigsaw - Multilingual Toxic Comment Classification
  • Built a multilingual text classification model to predict the probability that a comment is toxic using the data provided by Google Jigsaw.
  • The data had 435,775 text comments in 7 different languages.
  • A RNN model was used as a baseline. The BERT-Multilingual-base and XLMRoBERTa models were fine-tuned to get the best results.
  • The best results were obtained using the fine-tuned XLMRoberta model. It achieved an Accuracy of 96.24% and an ROC-AUC Score of 93.92%.
    Data

    A single toxic comment can ruin an online discussion. Toxicity is anything that is rude, disrespectful, or likely to make someone leave the conversation. If we can identify these toxic contributions, we can create a safer and more collaborative internet.

    The goal is to find the probability that a comment is toxic. This can be done by using machine learning algorithms to analyze the text of a comment and identify words or phrases that are associated with toxicity. The algorithm can then calculate the probability that a comment is toxic based on the number of toxic words or phrases it contains.

    This information can then be used to flag toxic comments for review by human moderators. By identifying and removing toxic comments, we can help to create a more positive and productive online environment for everyone.

    Columns in the dataset:

    id - identifier within each file.
    comment_text - the text of the comment to be classified.
    lang - the language of the comment.
    toxic - whether or not the comment is classified as toxic.
    

    The comments are composed of multiple non-English languages and come either from Civil Comments or Wikipedia talk page edits.

    The dataset can be downloaded from Kaggle.

    Experiments:

    RNN:

    A baseline was created using an RNN model with an embedding layer of size 64. The model was trained with an Adam optimizer and a learning rate of 0.001 for 5 epochs. The RNN reached an accuracy of 83.68%, but its ROC-AUC score of 55.72% was only marginally better than chance, showing that it struggled to separate toxic from non-toxic comments.

    BERT-Multilingual-base:

    To leverage pre-trained transformer models, BERT-Multilingual-base was fine-tuned on the provided dataset with an additional hidden layer of 1024 neurons. Training used the Adam optimizer with a learning rate of 0.001 and a weight decay of 1e-6 for 10 epochs. The fine-tuned model reached an accuracy of 93.92% and an ROC-AUC score of 89.55%, a large improvement over the RNN baseline for multilingual toxic comment classification.

    XLM RoBERTa:

    To further improve accuracy and generalization, the XLM-RoBERTa model was fine-tuned on the dataset using the AdamW optimizer with a learning rate of 1e-5 and a weight decay of 1e-5 for 7 epochs. The fine-tuned XLM-RoBERTa achieved an accuracy of 96.24% and an ROC-AUC score of 93.92%, the best results of all the models, demonstrating a strong ability to identify toxic comments across languages.

    For all the models that were fine-tuned:

    • A batch size of 64 was employed during the training process, ensuring efficient computation and effective utilization of computational resources.
    • Binary Cross-Entropy, a standard loss function for binary classification, was used to train the models, measuring the dissimilarity between the predicted and true toxicity labels.
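    As a rough illustration of this fine-tuning setup (not the exact training script), a single-output XLM-RoBERTa classifier trained with binary cross-entropy might look like the following sketch, where the batch is a placeholder:

    import torch
    from torch.optim import AdamW
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    # Assumed setup: one output logit per comment, trained with BCE on the toxic label
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
    model = AutoModelForSequenceClassification.from_pretrained("xlm-roberta-base", num_labels=1)

    optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=1e-5)
    loss_fn = torch.nn.BCEWithLogitsLoss()

    # Placeholder batch; in practice this comes from a DataLoader over the dataset
    comments = ["You are awful!", "Thanks for the helpful explanation."]
    labels = torch.tensor([[1.0], [0.0]])

    batch = tokenizer(comments, padding=True, truncation=True, return_tensors="pt")
    logits = model(**batch).logits  # shape: (batch_size, 1)
    loss = loss_fn(logits, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()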

    Results:

    Among the models tested, the fine-tuned XLM-RoBERTa model achieved the highest scores and was therefore used to generate the final predictions. It reached an accuracy of 96.24% and an ROC-AUC score of 93.92%, making it the clear top performer for detecting toxic comments across multiple languages.

    The results from all the models have been summarized below:

    Model | Accuracy | ROC-AUC Score
    RNN | 83.68 | 55.72
    BERT-Multilingual-base (fine-tuned) | 93.92 | 89.55
    XLM-RoBERTa (fine-tuned) | 96.24 | 93.92

    Run Locally

    The code to run the project can be found here: Jigsaw - Multilingual Toxic Comment Classification Github.

    1. Install required libraries:
      pip install -r requirements.txt

    2. Baseline model:
      python toxic-baseline-rnn.py

    3. Fine-tune models:
      python toxic-bertm-base.py
      python toxic-xlm-roberta.py

    Feedback

    If you have any feedback, please reach out to me.

    Sports Image Classification

    Have you ever wondered how machines can recognize sports just by looking at images? Well, it's all thanks to the power of Convolutional Neural Networks (CNNs). In this article, we will explore how CNNs can be used to classify sports images and compare the performance of different CNN architectures.

    Dataset

    The dataset used for this project is a collection of images representing 100 different types of sports and activities. The sports range from traditional sports like "archery", "arm wrestling", "bowling", "football", "water polo", "weightlifting" to non-traditional ones like "wingsuit flying" and "nascar racing". The goal is to predict the correct sport based on the image.

    The dataset consists of 13572 train, 500 test, and 500 validation images. Each image is of size 224 x 224 pixels and has been segregated into train, test, and valid directories. The dataset can be downloaded from Kaggle.

    Experiments

    To classify the sports images, we compared the performance of five different CNN architectures: a custom CNN, InceptionV3, ResNet50V2, MobileNetV2, and EfficientNetB3. For each pre-trained model, we fine-tuned the ImageNet weights and added two hidden layers of 256 and 128 neurons with leaky-ReLU activations. Dropout layers with p=0.1 were added to prevent overfitting. Training ran for up to 50 epochs with early stopping (patience of 2 epochs), using a batch size of 32 and Sparse Categorical Cross-Entropy as the loss function.
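    A minimal Keras sketch of this transfer-learning setup, assuming TensorFlow/Keras and using the EfficientNetB3 backbone as an example (the data pipeline is omitted):

    import tensorflow as tf
    from tensorflow.keras import layers, models

    # Pre-trained ImageNet backbone without its classification head
    base_model = tf.keras.applications.EfficientNetB3(
        include_top=False, weights="imagenet", input_shape=(224, 224, 3), pooling="avg"
    )

    model = models.Sequential([
        base_model,
        layers.Dense(256), layers.LeakyReLU(), layers.Dropout(0.1),
        layers.Dense(128), layers.LeakyReLU(), layers.Dropout(0.1),
        layers.Dense(100, activation="softmax"),  # 100 sports classes
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    early_stopping = tf.keras.callbacks.EarlyStopping(patience=2, restore_best_weights=True)
    # model.fit(train_ds, validation_data=val_ds, epochs=50, callbacks=[early_stopping])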

    Custom CNN

    We started with a baseline custom CNN model with 3 convolution layers and 3 dense layers. A kernel of size 3 x 3 was used for all the convolution layers. Training the model with an Adam optimizer with a learning rate of 0.001 for 47 epochs yielded an Accuracy of 56.44%, F1-Score of 48.48%, and an ROC-AUC Score of 49.46%.

    InceptionV3

    The InceptionV3 model was initialized with pre-trained ImageNet weights. Only the Dense layers were fine-tuned. Training the model with an Adam optimizer with a learning rate of 0.001 for 22 epochs yielded an Accuracy of 68.92%, F1-Score of 64.92%, and an ROC-AUC Score of 66.64%.

    ResNet50V2

    The ResNet50V2 model was initialized with pre-trained ImageNet weights. Only the Dense layers were fine-tuned. Training the model with an Adam optimizer with a learning rate of 0.001 for 16 epochs yielded an Accuracy of 72.88%, F1-Score of 70.67%, and an ROC-AUC Score of 69.72%.

    MobileNetV2

    The MobileNetV2 model was initialized with pre-trained ImageNet weights, and all the layers were fine-tuned. Training the model with an Adam optimizer with a learning rate of 0.001 for 8 epochs yielded an Accuracy of 86.68%, F1-Score of 86.79%, and an ROC-AUC Score of 88.36%.

    EfficientNetB3

    The EfficientNetB3 model was initialized with pre-trained ImageNet weights, and all the layers were fine-tuned. Training the model with an Adam optimizer with a learning rate of 0.001 for 18 epochs yielded an Accuracy of 92.72%, F1-Score of 91.76%, and an ROC-AUC Score of 96.92%.

    Results

    The best results were obtained using a fine-tuned EfficientNetB3 model, which achieved an Accuracy of 92.72%, F1-Score of 91.76%, and an ROC-AUC Score of 96.92%. The results from all the models have been summarized in the table below:

    Model | Accuracy | F1-Score | ROC-AUC Score
    Custom CNN | 56.44 | 48.48 | 49.46
    InceptionV3 (fine-tuned) | 68.92 | 64.92 | 66.64
    ResNet50V2 (fine-tuned) | 72.88 | 70.67 | 69.72
    MobileNetV2 (fine-tuned) | 86.68 | 86.79 | 88.36
    EfficientNetB3 (fine-tuned) | 92.72 | 91.76 | 96.92

    Deploying the Model

    A web app was made using Streamlit to make predictions for new images using the best model. The live app can be viewed here.

    Run Locally

    All the code for this project can be found in this Github repository. To run the app locally, you can follow the instructions provided in the repository.

    1. Install required libraries:
        pip install -r streamlit/requirements.txt
      
    2. Fine-tune models:
        python sports-clf-custom-cnn.py
        python sports-clf-inception.py
        python sports-clf-resnet.py
        python sports-clf-mobilenet.py
        python sports-clf-efficientnet.py
      
    3. Generate predictions:
        python sports-clf-final-predictions.py
      

    Conclusion

    In conclusion, we have seen how CNNs can be used to classify sports images with high accuracy. We compared the performance of five different CNN architectures and found that the EfficientNetB3 model achieved the best results. This model can be used to classify sports images in real-time, which can be useful in various applications such as sports analytics, sports broadcasting, and sports betting.
