Spaces:
Runtime error
Runtime error
| # Copyright 2024 The HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import re | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Optional | |
| from datasets import load_dataset | |
| from transformers import HfArgumentParser | |
@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Declared as a dataclass so the fields below become constructor/CLI
    arguments (argument parsers that introspect dataclass fields rely on this).

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`):
            Hugging Face repository ID to push the dataset to.
        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
            Number of workers to use for dataset processing.
    """

    push_to_hub: bool = False
    repo_id: str = "trl-lib/hh-rlhf-helpful-base"
    dataset_num_proc: Optional[int] = None
def common_start(str1: str, str2: str) -> str:
    """Return the longest prefix that `str1` and `str2` have in common.

    The result is empty when the first characters differ or when either
    string is empty.
    """
    # Only positions present in both strings can match.
    limit = min(len(str1), len(str2))
    end = 0
    # Advance until the first mismatch or until the shorter string runs out.
    while end < limit and str1[end] == str2[end]:
        end += 1
    return str1[:end]
def extract_dialogue(example: Dict[str, str]) -> Dict[str, List[Dict[str, str]]]:
    r"""
    Split one preference example into a shared prompt and two completions.

    Args:
        example (`Dict[str, str]`):
            Mapping with string values under the keys `"chosen"` and `"rejected"`. Each value is a
            transcript whose turns are introduced by `"\n\nHuman: "` and `"\n\nAssistant: "`.

    Returns:
        `Dict[str, List[Dict[str, str]]]`: A dict with three keys:

        - `"prompt"`: list of `{"role", "content"}` messages for the shared conversation, with role
          `"user"` for human turns and `"assistant"` otherwise.
        - `"chosen"`: single-message list holding the chosen completion, role `"assistant"`.
        - `"rejected"`: single-message list holding the rejected completion, role `"assistant"`.
    """
    generation_prompt = "\n\nAssistant: "

    # The prompt corresponds to the common start of the chosen and rejected transcripts.
    prompt_text = common_start(example["chosen"], example["rejected"])

    # The two completions may themselves start with the same characters, in which case the common
    # start runs past the last assistant marker; cut it back so it ends exactly on that marker.
    # NOTE(review): if the marker is absent, rfind returns -1 and the slice drops one character;
    # this assumes every example contains at least one shared assistant marker.
    if not prompt_text.endswith(generation_prompt):
        prompt_text = prompt_text[: prompt_text.rfind(generation_prompt)] + generation_prompt

    # Everything after the shared prompt is the completion text.
    chosen_line = example["chosen"][len(prompt_text) :]
    rejected_line = example["rejected"][len(prompt_text) :]

    # Remove the trailing generation prompt ("\n\nAssistant: ") from the prompt.
    prompt_text = prompt_text[: -len(generation_prompt)]

    # Splitting with a capturing group keeps the markers, giving
    # ["", marker, content, marker, content, ...]; the leading element is dropped.
    pieces = re.split(r"(\n\nAssistant: |\n\nHuman: )", prompt_text)[1:]
    prompt = [
        {"role": "user" if marker == "\n\nHuman: " else "assistant", "content": content}
        for marker, content in zip(pieces[::2], pieces[1::2])
    ]

    # Both completions are assistant turns (the chosen role was previously misspelled "assitant").
    chosen = [{"role": "assistant", "content": chosen_line}]
    rejected = [{"role": "assistant", "content": rejected_line}]
    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
def runner(arguments):
    """
    Parse the script arguments, load the dataset and run the dialogue extraction over it.

    Args:
        arguments:
            Passed straight to `HfArgumentParser`; presumably the `ScriptArguments` type
            (TODO confirm against the caller).

    Returns:
        `None`. The mapped dataset is not returned or stored.
    """
    parsed = HfArgumentParser(arguments).parse_args_into_dataclasses()
    options = parsed[0]

    raw = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
    raw.map(extract_dialogue, num_proc=options.dataset_num_proc)

    # NOTE: pushing the result to the Hub (guarded by `push_to_hub`, targeting `repo_id`) is
    # currently disabled, so the function ends here.
    return None