GPT-NeoXT-Chat-Base-20B foundation model for chatbot applications is now available on Amazon SageMaker

Today we are excited to announce that Together Computer’s GPT-NeoXT-Chat-Base-20B language foundation model is available for customers using Amazon SageMaker JumpStart. GPT-NeoXT-Chat-Base-20B is an open-source model to build conversational bots. You can easily try out this model and use it with JumpStart. JumpStart is the machine learning (ML) hub of Amazon SageMaker that provides access to foundation models in addition to built-in algorithms and end-to-end solution templates to help you quickly get started with ML.

In this post, we walk through how to deploy the GPT-NeoXT-Chat-Base-20B model and invoke the model within an OpenChatKit interactive shell. This demonstration provides an open-source foundation model chatbot for use within your application.

JumpStart models use Deep Java Serving, which uses the Deep Java Library (DJL) with DeepSpeed libraries to optimize models and minimize inference latency. The underlying implementation in JumpStart is similar to the one in the following notebook. As a JumpStart model hub customer, you get improved performance without having to maintain the model script outside of the SageMaker SDK. JumpStart models also achieve an improved security posture with endpoints that enable network isolation.

Foundation models in SageMaker

JumpStart provides access to a range of models from popular model hubs, including Hugging Face, PyTorch Hub, and TensorFlow Hub, which you can use within your ML development workflow in SageMaker. Recent advances in ML have given rise to a new class of models known as foundation models, which typically contain billions of parameters and are adaptable to a wide category of use cases, such as text summarization, generating digital art, and language translation. Because these models are expensive to train, customers want to use existing pre-trained foundation models and fine-tune them as needed, rather than train these models themselves. SageMaker provides a curated list of models that you can choose from on the SageMaker console.

You can now find foundation models from different model providers within JumpStart, enabling you to get started with foundation models quickly. You can find foundation models based on different tasks or model providers, and easily review model characteristics and usage terms. You can also try out these models using a test UI widget. When you want to use a foundation model at scale, you can do so easily without leaving SageMaker by using pre-built notebooks from model providers. Because the models are hosted and deployed on AWS, you can rest assured that your data, whether used for evaluating or using the model at scale, is never shared with third parties.

GPT-NeoXT-Chat-Base-20B foundation model

Together Computer developed GPT-NeoXT-Chat-Base-20B, a 20-billion-parameter language model, fine-tuned from EleutherAI’s GPT-NeoX model with over 40 million instructions, focusing on dialog-style interactions. Additionally, the model is tuned on several tasks, such as question answering, classification, extraction, and summarization. The model is based on the OIG-43M dataset that was created in collaboration with LAION and Ontocord.

In addition to the aforementioned fine-tuning, GPT-NeoXT-Chat-Base-20B-v0.16 has also undergone further fine-tuning via a small amount of feedback data. This allows the model to better adapt to human preferences in conversations. GPT-NeoXT-Chat-Base-20B is designed for use in chatbot applications and may not perform well for use cases outside of its intended scope. Together Computer, Ontocord, and LAION collaborated to release OpenChatKit, an open-source alternative to ChatGPT with a comparable set of capabilities. OpenChatKit was launched under an Apache-2.0 license, granting complete access to the source code, model weights, and training datasets. There are several tasks that OpenChatKit excels at out of the box. These include summarization tasks, extraction tasks that pull structured information from unstructured documents, and classification tasks that assign a sentence or paragraph to different categories.

Let’s explore how we can use the GPT-NeoXT-Chat-Base-20B model in JumpStart.

Solution overview

You can find the code showing the deployment of GPT-NeoXT-Chat-Base-20B on SageMaker and an example of how to use the deployed model in a conversational manner using the command shell in the following GitHub notebook.

In the following sections, we expand each step in detail to deploy the model and then use it to solve different tasks:

  1. Set up prerequisites.
  2. Select a pre-trained model.
  3. Retrieve artifacts and deploy an endpoint.
  4. Query the endpoint and parse a response.
  5. Use an OpenChatKit shell to interact with your deployed endpoint.

Set up prerequisites

This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel and in a SageMaker notebook instance with the conda_python3 kernel.

Before you run the notebook, use the following command to complete some initial steps required for setup:

%pip install --upgrade sagemaker --quiet

Select a pre-trained model

We set up a SageMaker session as usual using Boto3 and then select the model ID that we want to deploy:

model_id, model_version = "huggingface-textgeneration2-gpt-neoxt-chat-base-20b-fp16", "*"

Retrieve artifacts and deploy an endpoint

With SageMaker, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by retrieving the instance_type, image_uri, and model_uri for the pre-trained model. To host the pre-trained model, we create an instance of sagemaker.model.Model and deploy it. The following code uses ml.g5.24xlarge for the inference endpoint. The deploy method may take a few minutes.

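from sagemaker import image_uris, instance_types, model_uris
from sagemaker.deserializers import JSONDeserializer
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.session import Session
from sagemaker.utils import name_from_base

# IAM role for the endpoint (assumes the notebook runs with a SageMaker execution role).
aws_role = Session().get_caller_identity_arn()

endpoint_name = name_from_base(f"jumpstart-example-{model_id}")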

# Retrieve the inference instance type for the specified model.
instance_type = instance_types.retrieve_default(
    model_id=model_id, model_version=model_version, scope="inference"
)

# Retrieve the inference docker container uri.
image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=instance_type,
)

# Retrieve the model uri.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)

# Create the SageMaker model instance. The inference script is prepacked with the model artifact.
model = Model(
    image_uri=image_uri,
    model_data=model_uri,
    role=aws_role,
    predictor_cls=Predictor,
    name=endpoint_name,
)

# Set the serializer/deserializer used to run inference through the sagemaker API.
serializer = JSONSerializer()
deserializer = JSONDeserializer()

# Deploy the Model.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name,
    serializer=serializer,
    deserializer=deserializer
)

Query the endpoint and parse the response

Next, we show you an example of how to invoke an endpoint with a subset of the hyperparameters:

payload = {
    "text_inputs": "<human>: Tell me the steps to make a pizza\n<bot>:",
    "max_length": 500,
    "max_time": 50,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
    "stopping_criteria": ["<human>"],
}
response = predictor.predict(payload)
print(response[0][0]["generated_text"])

The following is the response that we get:

<human>: Tell me the steps to make a pizza
<bot>: 1. Choose your desired crust, such as thin-crust or deep-dish. 
2. Preheat the oven to the desired temperature. 
3. Spread sauce, such as tomato or garlic, over the crust. 
4. Add your desired topping, such as pepperoni, mushrooms, or olives. 
5. Add your favorite cheese, such as mozzarella, Parmesan, or Asiago. 
6. Bake the pizza according to the recipe instructions. 
7. Allow the pizza to cool slightly before slicing and serving.
<human>:
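Because the generated text echoes the prompt and ends with the stop word, a client usually wants only the new bot text. The following is a minimal sketch of that post-processing, assuming the nested response shape used above (response[0][0]["generated_text"]). The helper name extract_bot_reply and the hand-written response are hypothetical and not part of the SageMaker SDK or OpenChatKit.

```python
from typing import Any, Sequence


def extract_bot_reply(response: Any, prompt: str, stop_words: Sequence[str] = ("<human>",)) -> str:
    """Return only the newly generated bot text from an endpoint response."""
    generated_text = response[0][0]["generated_text"]
    # The generated text echoes the prompt, so drop that prefix when present.
    if generated_text.startswith(prompt):
        reply = generated_text[len(prompt):]
    else:
        reply = generated_text
    # Cut the reply at the earliest stop word (for example, "<human>").
    cut = len(reply)
    for word in stop_words:
        index = reply.find(word)
        if index != -1:
            cut = min(cut, index)
    return reply[:cut].strip()


# Hand-written response in the same shape as the one shown above.
prompt = "<human>: Tell me the steps to make a pizza\n<bot>:"
fake_response = [[{"generated_text": prompt + " 1. Choose your desired crust.\n<human>:"}]]
print(extract_bot_reply(fake_response, prompt))
```

With a live endpoint, you would pass the value returned by predictor.predict(payload) instead of the hand-written response.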

Here, we have provided the payload argument "stopping_criteria": ["<human>"], which has resulted in the model response ending with the generation of the word sequence <human>. The JumpStart model script will accept any list of strings as desired stop words, convert this list to a valid stopping_criteria keyword argument to the transformers generate API, and stop text generation when the output sequence contains any specified stop words. This is useful for two reasons: first, inference time is reduced because the endpoint doesn’t continue to generate undesired text beyond the stop words, and second, this prevents the OpenChatKit model from hallucinating additional human and bot responses until other stop criteria are met.
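The effect of stop words can be illustrated without a model. The following sketch is not the JumpStart model script. It assumes a hypothetical stream of text chunks standing in for decoded tokens, and it stops as soon as the accumulated output contains any stop word.

```python
from typing import Iterable, Sequence


def generate_until_stop(chunks: Iterable[str], stop_words: Sequence[str]) -> str:
    """Accumulate text chunks and stop once the output contains a stop word."""
    output = ""
    for chunk in chunks:
        output += chunk
        # Stop when any stop word appears anywhere in the output so far.
        if any(word in output for word in stop_words):
            break
    return output


# Hypothetical token stream: without the stop word, generation would continue
# into an invented follow-up human turn.
stream = ["<bot>:", " Bake", " the", " pizza.", "\n", "<human>", ":", " What", " next?"]
text = generate_until_stop(stream, ["<human>"])
print(repr(text))
```

The chunks after the stop word are never consumed, which is the source of both the latency saving and the protection against invented extra turns.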

Use an OpenChatKit shell to interact with your deployed endpoint

OpenChatKit provides a command line shell to interact with the chatbot. In this step, you create a version of this shell that can interact with your deployed endpoint. We provide a bare-bones simplification of the inference scripts in this OpenChatKit repository that can interact with our deployed SageMaker endpoint.

There are two main components to this:

  • A shell interpreter (JumpStartOpenChatKitShell) that allows for iterative inference invocations of the model endpoint
  • A conversation object (Conversation) that stores previous human/chatbot interactions locally within the interactive shell and appropriately formats past conversations for future inference context

The Conversation object is imported as is from the OpenChatKit repository. The following code creates a custom shell interpreter that can interact with your endpoint. This is a simplified version of the OpenChatKit implementation. We encourage you to explore the OpenChatKit repository to see how you can use more in-depth features, such as token streaming, moderation models, and retrieval augmented generation, within this context. The context of this notebook focuses on demonstrating a minimal viable chatbot with a JumpStart endpoint; you can add complexity as needed from here.
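To make the role of the conversation object concrete, the following is a simplified stand-in, not the OpenChatKit implementation. It reuses the method names that the shell calls (push_human_turn, push_model_response, get_raw_prompt, get_last_turn). The class name MiniConversation and the exact prompt layout are assumptions for illustration.

```python
from typing import List, Tuple


class MiniConversation:
    """Simplified stand-in for a conversation history (not the OpenChatKit class)."""

    def __init__(self, human_id: str, bot_id: str) -> None:
        self.human_id = human_id
        self.bot_id = bot_id
        self.turns: List[Tuple[str, str]] = []  # (speaker_id, text) pairs

    def push_human_turn(self, text: str) -> None:
        self.turns.append((self.human_id, text.strip()))

    def push_model_response(self, text: str) -> None:
        # Drop anything from a trailing human marker onward (left by the stop word).
        cleaned = text.split(self.human_id)[0].strip()
        self.turns.append((self.bot_id, cleaned))

    def get_raw_prompt(self) -> str:
        # Past turns, followed by an open bot turn for the model to complete.
        lines = [f"{speaker}: {text}" for speaker, text in self.turns]
        return "\n".join(lines + [f"{self.bot_id}:"])

    def get_last_turn(self) -> str:
        return self.turns[-1][1]


conversation = MiniConversation("<human>", "<bot>")
conversation.push_human_turn("What is the capital of US?")
conversation.push_model_response(" The capital of US is Washington, D.C.\n<human>:")
conversation.push_human_turn("How far it is from PA ?")
print(conversation.get_raw_prompt())
```

Because every earlier turn is replayed in the prompt, the model can resolve follow-up questions that refer back to previous answers.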

A short demo to showcase the JumpStartOpenChatKitShell is shown in the following video.

The following snippet shows how the code works:

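import cmd
from typing import List, Optional

from sagemaker.predictor import Predictor

# Conversation is the class taken as is from the OpenChatKit repository.


class JumpStartOpenChatKitShell(cmd.Cmd):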
    intro = (
        "Welcome to the OpenChatKit chatbot shell, modified to use a SageMaker JumpStart endpoint! Type /help or /? to "
        "list commands. For example, type /quit to exit shell.\n"
    )
    prompt = ">>> "
    human_id = "<human>"
    bot_id = "<bot>"
    
    def __init__(self, predictor: Predictor, cmd_queue: Optional[List[str]] = None, **kwargs):
        super().__init__()
        self.predictor = predictor
        self.payload_kwargs = kwargs
        self.payload_kwargs["stopping_criteria"] = [self.human_id]
        if cmd_queue is not None:
            self.cmdqueue = cmd_queue

    def preloop(self):
        self.conversation = Conversation(self.human_id, self.bot_id)

    def precmd(self, line):
        command = line[1:] if line.startswith('/') else 'say ' + line
        return command

    def do_say(self, arg):
        self.conversation.push_human_turn(arg)
        prompt = self.conversation.get_raw_prompt()
        payload = {"text_inputs": prompt, **self.payload_kwargs}
        response = self.predictor.predict(payload)
        output = response[0][0]["generated_text"][len(prompt):]
        self.conversation.push_model_response(output)
        print(self.conversation.get_last_turn())

    def do_reset(self, arg):
        self.conversation = Conversation(self.human_id, self.bot_id)

    def do_hyperparameters(self, arg):
        print(f"Hyperparameters: {self.payload_kwargs}\n")

    def do_quit(self, arg):
        return True

You can now launch this shell as a command loop. This will repeatedly issue a prompt, accept input, parse the input command, and dispatch actions. Because the resulting shell may be utilized in an infinite loop, this notebook provides a default command queue (cmdqueue) as a queued list of input lines. Because the last input is the command /quit, the shell will exit upon exhaustion of the queue. To dynamically interact with this chatbot, remove the cmdqueue.

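cmd_queue = [
    "Hello!",
    "/quit",  # last queued command, so the shell exits when the queue is exhausted
]
JumpStartOpenChatKitShell(
    predictor=predictor,  # __init__ expects the deployed Predictor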
    cmd_queue=cmd_queue,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.6,
    top_k=40,
).cmdloop()
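To try the command queue mechanism without a deployed endpoint, the following sketch uses a tiny shell that echoes its input instead of calling a model. The EchoShell class is hypothetical. It only mirrors the routing in precmd and the use of cmdqueue from Python's standard cmd module.

```python
import cmd
from typing import List, Optional


class EchoShell(cmd.Cmd):
    """Tiny shell that routes input like the chatbot shell but needs no endpoint."""

    prompt = ">>> "

    def __init__(self, cmd_queue: Optional[List[str]] = None):
        super().__init__()
        self.replies: List[str] = []
        if cmd_queue is not None:
            # Queued lines are consumed before the shell reads interactive input.
            self.cmdqueue = list(cmd_queue)

    def precmd(self, line):
        # "/name" runs do_name; any other text is treated as a chat message.
        return line[1:] if line.startswith("/") else "say " + line

    def do_say(self, arg):
        reply = f"echo: {arg}"  # a real shell would call the endpoint here
        self.replies.append(reply)
        print(reply)

    def do_quit(self, arg):
        return True  # a truthy return value ends cmdloop()


shell = EchoShell(cmd_queue=["Hello!", "/quit"])
shell.cmdloop()
```

Because the queue ends with /quit, the loop finishes without waiting for keyboard input. Without a queue, the same shell prompts interactively.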

Example 1: Conversation context is retained

The following prompt shows that the chatbot is able to retain the context of the conversation to answer follow-up questions:

Welcome to the OpenChatKit chatbot shell, modified to use a SageMaker JumpStart endpoint! Type /help or /? to list commands. For example, type /quit to exit shell.

<<<  Hello! How may I help you today? 

>>>  What is the capital of US?

<<<  The capital of US is Washington, D.C.

>>>  How far it is from PA ?

<<<  It is approximately 1100 miles.

Example 2: Classification of sentiments

In the following example, the chatbot performed a classification task by identifying the sentiments of the sentence. As you can see, the chatbot was able to classify positive and negative sentiments successfully.

Welcome to the OpenChatKit chatbot shell, modified to use a SageMaker JumpStart endpoint! Type /help or /? to list commands. For example, type /quit to exit shell.

<<<  Hello! How may I help you today?

>>>  What is the sentiment of this sentence "The weather is good and I am going to play outside, it is sunny and warm"

<<<  POSITIVE

>>>  What is the sentiment of this sentence " The news this morning was tragic and it created lot of fear and concerns in city"

<<<  NEGATIVE

Example 3: Summarization tasks

Next, we tried summarization tasks with the chatbot shell. The following example shows how the long text about Amazon Comprehend was summarized to one sentence and the chatbot was able to answer follow-up questions on the text:

Welcome to the OpenChatKit chatbot shell, modified to use a SageMaker JumpStart endpoint! Type /help or /? to list commands. For example, type /quit to exit shell.

<<<  Hello! How may I help you today?

>>>  Write a short summary for this text: Amazon Comprehend uses natural language processing (NLP) to extract insights about the content of documents. It develops insights by recognizing the entities, key phrases, language, sentiments, and other common elements in a document. Use Amazon Comprehend to create new products based on understanding the structure of documents. For example, using Amazon Comprehend you can search social networking feeds for mentions of products or scan an entire document repository for key phrases. \nYou can access Amazon Comprehend document analysis capabilities using the Amazon Comprehend console or using the Amazon Comprehend APIs. You can run real-time analysis for small workloads or you can start asynchronous analysis jobs for large document sets. You can use the pre-trained models that Amazon Comprehend provides, or you can train your own custom models for classification and entity recognition. \nAll of the Amazon Comprehend features accept UTF-8 text documents as the input. In addition, custom classification and custom entity recognition accept image files, PDF files, and Word files as input. \nAmazon Comprehend can examine and analyze documents in a variety of languages, depending on the specific feature. For more information, see Languages supported in Amazon Comprehend. Amazon Comprehend's Dominant language capability can examine documents and determine the dominant language for a far wider selection of languages.

<<<  Amazon Comprehend is a natural language processing (NLP) service that uses deep learning and other advanced AI techniques to extract insights from unstructured text.

>>>  What kind of files does custom entity recognition accept as inputs?

<<<  Image files, PDF files, and Word files

Example 4: Extract structured information from unstructured text

In the following example, we used the chatbot to create a markdown table with headers, rows, and columns to create a project plan using the information that is provided in free-form language:

Welcome to the OpenChatKit chatbot shell, modified to use a SageMaker JumpStart endpoint! Type /help or /? to list commands. For example, type /quit to exit shell.

<<<  Hello! How may I help you today?

>>>  Generate a table summarizing the options outlined in this email.   Team, we need to plan a project.   The first task is for creating a web app and will take 3 developers, 2 testers with duration of 3 weeks. This is priority 1 The second task is for refactoring a web app and will take 2 developers, 5 testers with duration of 4 weeks. This is priority 2  A markdown table with 2 rows and six columns: (1) Task ID , (2) Task Description, (3) Developers, (4) Testers, (5) Duration, (6) Priority

<<<  | Task ID | Task Description | Developers | Testers | Duration | Priority |
| --------- | --------- | --------- | --------- | --------- | --------- |
| 1 | Create a web app | 3 | 2 | 3 weeks | 1 |
| 2 | Refactor a web app | 2 | 5 | 4 weeks | 2 |

Example 5: Commands as input to chatbot

We can also enter commands, such as /hyperparameters to see the hyperparameter values and /quit to exit the command shell:

>>>  /hyperparameters

<<<  Hyperparameters: {'max_new_tokens': 128, 'do_sample': True, 'temperature': 0.6, 'top_k': 40, 'stopping_criteria': ['<human>']}

>>>  /quit

These examples showcased just some of the tasks that OpenChatKit excels at. We encourage you to try various prompts and see what works best for your use case.

Clean up

After you have tested the endpoint, make sure you delete the SageMaker inference endpoint and the model to avoid incurring charges.
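A minimal cleanup sketch follows, assuming predictor is the object returned by model.deploy earlier in this post. The wrapper name clean_up is hypothetical. The delete_model and delete_endpoint calls are methods of the SageMaker SDK Predictor.

```python
def clean_up(predictor) -> None:
    """Delete the model and the endpoint behind a SageMaker Predictor."""
    # Delete the model first, while the endpoint still exists for the SDK
    # to resolve which model it serves.
    predictor.delete_model()
    # Then delete the endpoint itself.
    predictor.delete_endpoint()


# Usage in the notebook, once you are done with the endpoint:
# clean_up(predictor)
```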

Conclusion

In this post, we showed you how to test and use the GPT-NeoXT-Chat-Base-20B model using SageMaker and build interesting chatbot applications. Try out the foundation model in SageMaker today and let us know your feedback!

This guidance is for informational purposes only. You should still perform your own independent assessment, and take measures to ensure that you comply with your own specific quality control practices and standards, and the local rules, laws, regulations, licenses and terms of use that apply to you, your content, and the third-party model referenced in this guidance. AWS has no control or authority over the third-party model referenced in this guidance, and does not make any representations or warranties that the third-party model is secure, virus-free, operational, or compatible with your production environment and standards. AWS does not make any representations, warranties or guarantees that any information in this guidance will result in a particular outcome or result.


About the authors

Rachna Chadha is a Principal Solutions Architect AI/ML in Strategic Accounts at AWS. Rachna is an optimist who believes that the ethical and responsible use of AI can improve society in the future and bring economic and social prosperity. In her spare time, Rachna likes spending time with her family, hiking, and listening to music.

Dr. Kyle Ulrich is an Applied Scientist with the Amazon SageMaker built-in algorithms team. His research interests include scalable machine learning algorithms, computer vision, time series, Bayesian non-parametrics, and Gaussian processes. His PhD is from Duke University and he has published papers in NeurIPS, Cell, and Neuron.

Dr. Ashish Khetan is a Senior Applied Scientist with Amazon SageMaker built-in algorithms and helps develop machine learning algorithms. He got his PhD from University of Illinois Urbana-Champaign. He is an active researcher in machine learning and statistical inference, and has published many papers in NeurIPS, ICML, ICLR, JMLR, ACL, and EMNLP conferences.