import re
import random
import math
from collections import defaultdict, Counter
from RT import dataset  

class AgGPT2:
    """Tiny intent-matching chatbot built on TF-IDF and cosine similarity.

    ``dataset`` maps an intent name to a list of phrases.  The same phrases
    serve two purposes: they are the training examples used to build one
    TF-IDF vector per intent, and they are the pool a reply is randomly drawn
    from once an intent has been recognised.
    """

    # Compiled once; both patterns are used for every token that is processed.
    _TOKEN_RE = re.compile(r'\b\w+\b')
    _VOWEL_CONSONANT_RE = re.compile(r'[aeiouy][^aeiouy]')

    def __init__(self, dataset):
        """Store ``dataset`` and immediately train the intent vectors."""
        self.dataset = dataset
        self.word_freq = defaultdict(int)      # stem -> number of phrases containing it
        self.intent_words = defaultdict(set)   # intent -> set of stems seen for it
        self.idf = defaultdict(float)          # stem -> smoothed inverse document frequency
        self.intent_vectors = {}               # intent -> {stem: tf-idf weight}
        self.context = []                      # most recent raw user inputs (max 3)
        self.train()

    def preprocess(self, text):
        """Lower-case ``text``, split it into word tokens and stem each one."""
        tokens = self._TOKEN_RE.findall(text.lower())
        return [self.porter_stem(token) for token in tokens]

    def porter_stem(self, word):
        """Return a stem of ``word`` using step 1 (a, b, c) of Porter's algorithm.

        Only plural, ``-ed``/``-ing`` and terminal ``-y`` handling is
        implemented; later Porter steps are not.  Words of one or two
        characters are returned unchanged.
        """
        if len(word) <= 2:
            return word

        vowels = 'aeiou'

        def measure(stem):
            # Porter's "m": the number of vowel->consonant transitions.
            # The whole stem is scanned; dropping a trailing consonant first
            # would undercount by one and break the ``m == 1`` rule below.
            return len(self._VOWEL_CONSONANT_RE.findall(stem))

        def contains_vowel(stem):
            # Porter's *v* condition.  'y' counts as a vowel unless it is the
            # first letter (a 'y' after a vowel is already covered by that vowel).
            return any(ch in vowels for ch in stem) or 'y' in stem[1:]

        def ends_with_double_consonant(stem):
            return len(stem) >= 2 and stem[-1] == stem[-2] and stem[-1] not in vowels

        def ends_with_cvc(stem):
            # Porter's *o condition: consonant-vowel-consonant where the last
            # consonant is not w, x or y.
            return (
                len(stem) >= 3
                and stem[-1] not in vowels
                and stem[-2] in vowels
                and stem[-3] not in vowels
                and stem[-1] not in 'wxy'
            )

        def tidy_after_suffix_removal(stem):
            # Step 1b clean-up applied once "-ed" / "-ing" has been removed.
            if stem.endswith(('at', 'bl', 'iz')):
                return stem + 'e'          # conflat -> conflate, siz -> size
            if ends_with_double_consonant(stem) and not stem.endswith(('l', 's', 'z')):
                return stem[:-1]           # hopp -> hop
            if measure(stem) == 1 and ends_with_cvc(stem):
                return stem + 'e'          # hop -> hope, fil -> file
            return stem

        # Step 1a: plurals.
        if word.endswith('sses'):
            word = word[:-2]
        elif word.endswith('ies'):
            word = word[:-2]
        elif word.endswith('ss'):
            pass
        elif word.endswith('s'):
            word = word[:-1]

        # Step 1b: "-eed", "-ed", "-ing".
        if word.endswith('eed'):
            if measure(word[:-3]) > 0:
                word = word[:-1]
        elif word.endswith('ed'):
            if contains_vowel(word[:-2]):
                word = tidy_after_suffix_removal(word[:-2])
        elif word.endswith('ing'):
            if contains_vowel(word[:-3]):
                word = tidy_after_suffix_removal(word[:-3])

        # Step 1c: terminal y -> i when the rest of the word has a vowel.
        if word.endswith('y') and contains_vowel(word[:-1]):
            word = word[:-1] + 'i'

        return word

    def train(self):
        """(Re)build document frequencies, IDF weights and intent vectors.

        Every phrase is one document.  All derived state is cleared first so
        calling ``train`` again does not double-count.
        """
        self.word_freq.clear()
        self.intent_words.clear()
        self.idf.clear()
        self.intent_vectors.clear()

        doc_count = 0
        for intent, phrases in self.dataset.items():
            for phrase in phrases:
                # Document frequency below is counted per phrase, so the
                # document total must be counted per phrase too; otherwise the
                # IDF ratio can drop below 1 and weights can become <= 0.
                doc_count += 1
                for word in set(self.preprocess(phrase)):
                    self.word_freq[word] += 1
                    self.intent_words[intent].add(word)

        for word, freq in self.word_freq.items():
            # Smoothed IDF: never divides by zero and is always >= 1.
            self.idf[word] = math.log((doc_count + 1) / (freq + 1)) + 1

        for intent, words in self.intent_words.items():
            self.intent_vectors[intent] = self.create_vector(words)

    def create_vector(self, words):
        """Return a ``{stem: tf * idf}`` mapping for the iterable ``words``.

        Stems that were never seen during training get weight 0.0.  An empty
        ``words`` yields an empty vector.
        """
        vector = {}
        total = len(words)
        for word, count in Counter(words).items():
            # ``.get`` instead of indexing: indexing the defaultdict would
            # permanently insert an entry for every unknown query word.
            vector[word] = (count / total) * self.idf.get(word, 0.0)
        return vector

    def cosine_similarity(self, vec1, vec2):
        """Return the cosine similarity of two sparse vectors (dicts).

        Returns 0.0 when either vector has zero magnitude.
        """
        shared = set(vec1) & set(vec2)
        numerator = sum(vec1[key] * vec2[key] for key in shared)

        sum1 = sum(value ** 2 for value in vec1.values())
        sum2 = sum(value ** 2 for value in vec2.values())
        denominator = math.sqrt(sum1) * math.sqrt(sum2)

        if not denominator:
            return 0.0
        return numerator / denominator

    def get_intent(self, user_input):
        """Return the best-matching intent name, or ``'unknown'``.

        The best cosine similarity must exceed 0.1 for an intent to be
        reported.
        """
        input_vector = self.create_vector(self.preprocess(user_input))

        best_intent = None
        best_score = 0

        for intent, intent_vector in self.intent_vectors.items():
            similarity = self.cosine_similarity(input_vector, intent_vector)
            if similarity > best_score:
                best_score = similarity
                best_intent = intent

        return best_intent if best_score > 0.1 else 'unknown'

    def generate_response(self, user_input):
        """Produce a reply for ``user_input``.

        A recognised intent yields a random phrase from that intent (with
        placeholders filled and a possible context follow-up appended);
        otherwise the unknown-input fallback is used.
        """
        intent = self.get_intent(user_input)
        if intent == 'unknown':
            return self.handle_unknown_input(user_input)

        responses = self.dataset[intent]
        response = self.fill_template(random.choice(responses), user_input)
        return self.apply_context(response, user_input)

    def handle_unknown_input(self, user_input):
        """Fallback reply when no intent scored above the threshold.

        If any stem of the input occurs in some intent's vocabulary, one such
        intent is picked at random and mentioned; otherwise a generic
        rephrase request is returned.
        """
        words = self.preprocess(user_input)
        relevant_intents = [
            intent
            for intent, intent_words in self.intent_words.items()
            if any(word in intent_words for word in words)
        ]

        if relevant_intents:
            intent = random.choice(relevant_intents)
            return f"I'm not sure I fully understand, but it sounds like you might be talking about {intent}. " + random.choice(self.dataset[intent])
        return "I'm not sure how to respond to that. Can you rephrase or ask something else?"

    def fill_template(self, template, user_input):
        """Replace ``{0}``, ``{1}``, ... in ``template`` with the stemmed input words.

        Placeholders whose index exceeds the number of input words are left
        untouched.
        """
        for i, word in enumerate(self.preprocess(user_input)):
            template = template.replace('{' + str(i) + '}', word)
        return template

    def apply_context(self, response, user_input):
        """Append a follow-up question if the topic changed since the previous turn.

        ``self.context[-2]`` is treated as the previous user message, which
        holds when the current input was appended to ``self.context`` before
        this call (as ``conversation`` does).
        """
        if len(self.context) >= 2:
            prev_intent = self.get_intent(self.context[-2])
            curr_intent = self.get_intent(user_input)
            if prev_intent != curr_intent and prev_intent != 'unknown':
                response += f" By the way, were you still interested in discussing {prev_intent}?"
        return response

    def save_user_input(self, user_input):
        """Append ``user_input`` as one line to ``learn.txt`` (best effort).

        The file is written as UTF-8 so non-ASCII input does not fail on
        platforms whose default encoding cannot represent it.  I/O errors are
        reported on stdout and otherwise ignored.
        """
        try:
            with open('learn.txt', 'a', encoding='utf-8') as f:
                f.write(user_input + '\n')
        except OSError:
            print("Error saving user input.")

    def conversation(self):
        """Run the interactive read-reply loop on stdin/stdout.

        The loop ends when the user types ``exit``, ``quit`` or ``bye``, when
        stdin is closed, or on Ctrl-C.
        """
        print("AgGPT-2: Hello! I'm an AI assistant. How can I help you today?")

        while True:
            try:
                user_input = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                # Closed stdin or Ctrl-C: finish politely instead of a traceback.
                print()
                print("AgGPT-2: Goodbye! It was nice chatting with you.")
                break

            if user_input.lower() in ['exit', 'quit', 'bye']:
                print("AgGPT-2: Goodbye! It was nice chatting with you.")
                break

            self.save_user_input(user_input)

            # Keep only the three most recent inputs as conversational context.
            self.context.append(user_input)
            self.context = self.context[-3:]

            response = self.generate_response(user_input)
            print(f"AgGPT-2: {response}")

if __name__ == "__main__":
    # Script entry point: build the bot from the imported dataset and start
    # the interactive console loop.
    AgGPT2(dataset).conversation()
