    [FEATURE] Feature/prepare for training feedbacktask (#3151) · c9bfaf28
    David Berenstein authored
    # Description
    
    I added a first rough outline of my ideas for `prepare_for_training`
    with the new `FeedbackDataset`. As discussed, there are 3 complexities
    (a minimal sketch of one way to handle the first two follows this
    list):

    - How to resolve annotator alignment?
    - How to resolve optional fields that have not been filled out? E.g.,
    "Please provide a correction for prompt 1."
    - How to handle potential concatenation of fields?
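
    A minimal plain-Python sketch (hypothetical helper, not the actual
    implementation) of what a "majority" unification over several
    annotators' responses could compute, skipping optional questions that
    were left unanswered:

    ```python
    from collections import Counter
    from typing import List, Optional

    def unify_majority(responses: List[dict], question: str) -> Optional[str]:
        """Majority vote over single-label answers; hypothetical helper."""
        votes = [
            r["values"][question]["value"]
            for r in responses
            if question in r.get("values", {})  # skip unanswered optional questions
        ]
        if not votes:
            # No annotator answered this (optional) question for the record.
            return None
        return Counter(votes).most_common(1)[0][0]

    responses = [
        {"values": {"relevant": {"value": "yes"}}},
        {"values": {"relevant": {"value": "no"}}},
        {"values": {"relevant": {"value": "yes"}}},
    ]
    print(unify_majority(responses, "relevant"))  # -> yes
    ```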
    
    To make it modular, I created a step-wise approach.
    
    1. `Pydantic` models that map and verify the data fields, as in the
    sketch after this list and the full example below. This keeps the
    flexibility to allow for other tasks like TextClassification, and it
    ensures we can use `dataset.fields` and `dataset.questions` directly
    for defining training. We could also use the `name` values from the
    fields/questions, but this might be more error-prone.
    2. `get_relevant_data_for_training()` returns a `List[dict]` with all
    relevant fields from the Pydantic model. This is where the **annotator
    alignment issue** comes in; for now I opted for choosing the first
    non-empty value.
    3. Forward the `List[dict]` to a flow similar to the one we had
    previously.
    4. Also added a `dataset.unify_responses(question, Enum(strategy))`
    method (a usage sketch follows the example below).
    5. Added `*QuestionUnification` schemas to hold the logic for unifying
    multiple responses.
    6. Added `client.feedback.training`.
    7. Added `TrainingDataFor*` to hold the logic for the
    `prepare_for_training` methods per task.
    8. Added inheritance for `ArgillaTrainer`.
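
    A hypothetical sketch of the step-1 idea, using stand-in classes
    rather than the actual argilla models: Pydantic validates that a
    training task is built from real field/question objects instead of
    raw name strings.

    ```python
    from typing import List

    from pydantic import BaseModel

    class TextFieldSketch(BaseModel):  # stand-in for rg.TextField
        name: str

    class LabelQuestionSketch(BaseModel):  # stand-in for rg.LabelQuestion
        name: str
        labels: List[str]

    class TextClassificationTaskSketch(BaseModel):
        # Binding the field/question objects (not just their names) means
        # the schema is validated up front, before training starts.
        text: TextFieldSketch
        label: LabelQuestionSketch

    task = TextClassificationTaskSketch(
        text=TextFieldSketch(name="text"),
        label=LabelQuestionSketch(name="relevant", labels=["yes", "no"]),
    )
    print(task.label.labels)  # -> ['yes', 'no']
    ```

    The full end-to-end example: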
    
    ```python 
    import argilla as rg
    from argilla import (
        FeedbackRecord,
        LabelQuestion,
        LabelQuestionUnification,
        MultiLabelQuestion,
        TrainingDataForTextClassification,
        ArgillaTrainer
    )
    
    dataset = rg.FeedbackDataset(
        guidelines="Add some guidelines for the annotation team here.",
        fields=[
            rg.TextField(name="text", title="Human prompt"),
        ],
    questions=[
        LabelQuestion(
            name="relevant",
            title="Is the response relevant for the given prompt?",
            labels=["yes", "no"],
            required=True,
            visible_labels=None,
        ),
        MultiLabelQuestion(
            name="content_class",
            title="Does the response include any of the following?",
            description="Select all that apply",
            labels={
                "hate": "Hate Speech",
                "sexual": "Sexual content",
                "violent": "Violent content",
                "pii": "Personal information",
                "untruthful": "Untruthful info",
                "not_english": "Not English",
                "inappropriate": "Inappropriate content",
            },
            required=False,
            visible_labels=4,
        ),
    ]
    )
    dataset.add_records(
        records=[
            FeedbackRecord(
                fields={"text": "What is your favorite color?"},
                responses=[{"values": {"relevant": {"value": "yes"}, "content_class": {"value": ["hate"]}}}]
            ),
            FeedbackRecord(
                fields={"text": "What do you think about the new iPhone?"},
                responses=[{"values": {"relevant": {"value": "no"}, "content_class": {"value": ["hate"]}}}]
            ),
            FeedbackRecord(
                fields={"text": "What is your feeling about the technology?"},
                responses=[{"values": {"relevant": {"value": "yes"}, "content_class": {"value": ["sexual"]}}},
                           {"values": {"relevant": {"value": "no"}, "content_class": {"value": ["hate", "sexual"]}}},
                           {"values": {"relevant": {"value": "yes"}, "content_class": {"value": ["hate", "sexual"]}}}]
            ),
            FeedbackRecord(
                fields={"text": "Jesus Christ!"},
                responses=[{"values": {"relevant": {"value": "no"}, "content_class": {"value": ["sexual"]}}},
                           {"values": {"relevant": {"value": "no"}, "content_class": {"value": ["hate"]}}}]
            )
        ]
    )
    
    # print(dataset.question_by_name("relevant").__all_labels__)
    
    label = LabelQuestionUnification(question=dataset.question_by_name("relevant"), strategy="majority")
    training_data = TrainingDataForTextClassification(text=dataset.field_by_name("text"), label=label)
    
    for framework in ["spacy", "transformers", "openai", "spark-nlp"]:
        formatted_data = dataset.prepare_for_training(framework, training_data, fetch_records=False, train_size=0.8)
        print(formatted_data)
    
    trainer = ArgillaTrainer(
        dataset=dataset,
        training_task_mapping=training_data,  # the task definition created above
        framework="setfit",
        fetch_records=False
    )
    trainer.train("test")
    ```
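
    For step 4, a usage sketch of the unification entry point; whether
    `strategy` also accepts a plain string instead of the enum is an
    assumption here:

    ```python
    # Assumed usage of the step-4 method described above; "majority" is
    # the same strategy used in the example.
    dataset.unify_responses(
        question=dataset.question_by_name("relevant"),
        strategy="majority",
    )
    ```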
    
    Closes #2954
    Closes #3184
    Closes #3152 
    
    **Type of change**
    
    - [X] New feature (non-breaking change which adds functionality)
    - [X] Improvement (change adding some improvement to an existing
    functionality)
    
    **How Has This Been Tested**
    
    - [ ] Test A
    - [ ] Test B
    
    **Checklist**
    
    - [ ] I have merged the original branch into my forked branch
    - [ ] I added relevant documentation
    - [ ] My code follows the style guidelines of this project
    - [ ] I did a self-review of my code
    - [ ] I made corresponding changes to the documentation
    - [ ] My changes generate no new warnings
    - [ ] I have added tests that prove my fix is effective or that my
    feature works
    - [ ] I have added relevant notes to the CHANGELOG.md file (see
    https://keepachangelog.com/)
    
    ---------
    
    Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
    Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>