import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import Trainer, TrainingArguments
from datasets import load_dataset

# Load the pretrained GPT-2 model and its tokenizer from the Hugging Face hub
# (downloads on first run, then uses the local cache).
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model.eval()  # inference mode: disables dropout; no gradients are needed


# GPT-2 ships without a pad token, so reuse EOS for padding (standard practice).
tokenizer.pad_token = tokenizer.eos_token
eos_token_id = tokenizer.eos_token_id
# Running transcript of the chat; generate_text() appends both user prompts
# and model replies here so each turn sees the full context.
conversation_history = []
# NOTE(review): eos_token_id / pad_token_id are not read anywhere below —
# generate_text() uses tokenizer.pad_token_id directly; kept for compatibility.
pad_token_id = eos_token_id

def generate_text(prompt, max_length=100):
    """Append *prompt* to the conversation and return the model's reply.

    Parameters
    ----------
    prompt : str
        The user's latest message.
    max_length : int
        Budget for NEW tokens to generate (parameter name kept for
        backward compatibility with existing callers).

    Returns
    -------
    str
        The newly generated reply only (prompt text is not echoed back).
    """
    conversation_history.append(prompt)
    conversation_input = "\n".join(conversation_history)

    # Left-truncate long histories to the model's context budget, but do NOT
    # pad: padding a single sequence to 512 wasted compute and, with the old
    # generate(max_length=100), made the input longer than the generation
    # budget, which errors out.
    inputs = tokenizer(
        conversation_input,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            # Use the tokenizer's real mask; the old code overwrote it with
            # all-ones, which would have attended to padding positions.
            attention_mask=inputs["attention_mask"],
            # max_new_tokens bounds only the reply; the old max_length
            # counted the prompt too and broke once history grew past it.
            max_new_tokens=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            # Must be passed to generate(); stashing it in the inputs dict
            # (the old code) had no effect.
            pad_token_id=tokenizer.pad_token_id,
        )

    # generate() returns prompt + reply; slice off the prompt tokens so the
    # history is not duplicated wholesale on every turn.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    conversation_history.append(generated_text)
    return generated_text

def chat_with_agGPT_5():
    """Run an interactive chat loop on the console until the user types 'exit'."""
    print("agGPT-5: Hi! How can I help you today?")
    # Keep prompting until the (case-insensitive) exit command is entered.
    while (user_input := input("You: ")).lower() != 'exit':
        reply = generate_text(user_input, max_length=550)
        print("agGPT-5:", reply)
    print("agGPT-5: Goodbye! Talk to you soon.")

# Only start the interactive loop when executed as a script; importing this
# module for its functions should not block on console input.
if __name__ == "__main__":
    chat_with_agGPT_5()