mirror of
https://github.com/decompme/decomp.me.git
synced 2026-05-03 05:26:43 -05:00
Merge branch 'main' of github.com:ethteck/decomp.me into json-diff
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
# This file wraps common django decorators in method_decorator for use with the APIView class
|
||||
|
||||
from django.utils.decorators import method_decorator
|
||||
from rest_framework.response import Response
|
||||
|
||||
from typing import Optional, Callable
|
||||
from datetime import datetime
|
||||
|
||||
def condition(etag_func: Optional[Callable[..., Optional[str]]] = None, last_modified_func: Optional[Callable[..., Optional[datetime]]] = None) -> Callable[..., Callable[..., Response]]:
|
||||
"""
|
||||
Handle Last-Modified and ETag headers.
|
||||
"""
|
||||
from django.views.decorators.http import condition
|
||||
return method_decorator(condition(etag_func, last_modified_func))
|
||||
@@ -14,9 +14,12 @@ if TYPE_CHECKING:
|
||||
class AnonymousUser(auth.models.AnonymousUser):
|
||||
profile: Profile
|
||||
|
||||
class Request(DRFRequest):
|
||||
user: Union[User, AnonymousUser]
|
||||
profile: Profile
|
||||
if TYPE_CHECKING:
|
||||
class Request(DRFRequest):
|
||||
user: Union[User, AnonymousUser]
|
||||
profile: Profile
|
||||
else:
|
||||
Request = DRFRequest
|
||||
|
||||
def disable_csrf(get_response):
|
||||
def middleware(request: HttpRequest):
|
||||
|
||||
@@ -3,7 +3,9 @@ from django.urls import reverse
|
||||
from django.contrib.auth.models import User
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
import responses
|
||||
from time import sleep
|
||||
|
||||
from .models import Compilation, Scratch, Profile
|
||||
from .github import GitHubUser
|
||||
@@ -224,3 +226,63 @@ class UserTests(APITestCase):
|
||||
response = self.client.get(f"/api/scratch/{slug}")
|
||||
self.assertEqual(response.json()["scratch"]["owner"]["username"], self.GITHUB_USER["login"])
|
||||
self.assertEqual(response.json()["scratch"]["owner"]["is_you"], True)
|
||||
|
||||
class ScratchDetailTests(APITestCase):
|
||||
def make_nop_scratch(self) -> Scratch:
|
||||
response = self.client.post(reverse("scratch"), {
|
||||
'arch': 'mips',
|
||||
'context': '',
|
||||
'target_asm': "jr $ra\nnop\n",
|
||||
})
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
|
||||
|
||||
scratch = Scratch.objects.first()
|
||||
assert scratch is not None # assert keyword instead of self.assertIsNotNone for mypy
|
||||
return scratch
|
||||
|
||||
def test_404_head(self):
|
||||
"""
|
||||
Ensure that HEAD requests 404 correctly.
|
||||
"""
|
||||
response = self.client.head(reverse("scratch-detail", args=["doesnt_exist"]))
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
|
||||
def test_last_modified(self):
|
||||
"""
|
||||
Ensure that the Last-Modified header is set.
|
||||
"""
|
||||
|
||||
scratch = self.make_nop_scratch()
|
||||
|
||||
response = self.client.head(reverse("scratch-detail", args=[scratch.slug]))
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assert_(response.headers.get("Last-Modified") is not None)
|
||||
|
||||
def test_if_modified_since(self):
|
||||
"""
|
||||
Ensure that the If-Modified-Since header is handled.
|
||||
"""
|
||||
|
||||
scratch = self.make_nop_scratch()
|
||||
|
||||
response = self.client.head(reverse("scratch-detail", args=[scratch.slug]))
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
last_modified = response.headers.get("Last-Modified")
|
||||
|
||||
# should be unmodified
|
||||
response = self.client.get(reverse("scratch-detail", args=[scratch.slug]), HTTP_IF_MODIFIED_SINCE=last_modified)
|
||||
self.assertEqual(response.status_code, status.HTTP_304_NOT_MODIFIED)
|
||||
|
||||
# Last-Modified is only granular to the second
|
||||
sleep(1)
|
||||
|
||||
# touch the scratch
|
||||
old_last_updated = scratch.last_updated
|
||||
scratch.slug = "newslug"
|
||||
scratch.save()
|
||||
self.assertNotEqual(scratch.last_updated, old_last_updated)
|
||||
|
||||
# should now be modified
|
||||
response = self.client.get(reverse("scratch-detail", args=[scratch.slug]), HTTP_IF_MODIFIED_SINCE=last_modified)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
@@ -4,8 +4,8 @@ from . import views
|
||||
|
||||
urlpatterns = [
|
||||
path('compilers', views.compilers, name='compilers'),
|
||||
path('scratch', views.scratch, name='scratch'), # TODO make this into its own view
|
||||
path('scratch/<slug:slug>', views.scratch, name='scratch-detail'),
|
||||
path('scratch', views.create_scratch, name='scratch'),
|
||||
path('scratch/<slug:slug>', views.ScratchDetail.as_view(), name='scratch-detail'),
|
||||
path('scratch/<slug:slug>/compile', views.compile, name='compile_scratch'),
|
||||
path('scratch/<slug:slug>/fork', views.fork, name='fork_scratch'),
|
||||
path('user', views.CurrentUser.as_view(), name="current-user"),
|
||||
|
||||
+118
-108
@@ -9,16 +9,16 @@ from rest_framework import serializers, status
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.decorators import api_view
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .models import Profile, Asm, Scratch
|
||||
from .models import Profile, Asm, Scratch, gen_scratch_id
|
||||
from .github import GitHubUser
|
||||
from .middleware import Request
|
||||
from coreapp.models import gen_scratch_id
|
||||
|
||||
from .decorators.django import condition
|
||||
|
||||
def get_db_asm(request_asm) -> Asm:
|
||||
h = hashlib.sha256(request_asm.encode()).hexdigest()
|
||||
@@ -34,125 +34,135 @@ def compilers(request):
|
||||
"compiler_ids": CompilerWrapper.available_compilers(),
|
||||
})
|
||||
|
||||
class ScratchDetail(APIView):
|
||||
# type-ignored due to python/mypy#7778
|
||||
def scratch_last_modified(request: Request, slug: str) -> Optional[datetime]: # type: ignore
|
||||
scratch: Optional[Scratch] = Scratch.objects.filter(slug=slug).first()
|
||||
if scratch:
|
||||
return scratch.last_updated
|
||||
else:
|
||||
return None
|
||||
|
||||
@api_view(["GET", "POST", "PATCH"])
|
||||
def scratch(request, slug=None):
|
||||
"""
|
||||
Get, create, or update a scratch
|
||||
"""
|
||||
scratch_condition = condition(last_modified_func=scratch_last_modified)
|
||||
|
||||
if request.method == "GET":
|
||||
if not slug:
|
||||
return Response("Missing slug", status=status.HTTP_400_BAD_REQUEST)
|
||||
@scratch_condition
|
||||
def head(self, request: Request, slug: str):
|
||||
get_object_or_404(Scratch, slug=slug) # for 404
|
||||
return Response()
|
||||
|
||||
db_scratch = get_object_or_404(Scratch, slug=slug)
|
||||
@scratch_condition
|
||||
def get(self, request: Request, slug: str):
|
||||
scratch = get_object_or_404(Scratch, slug=slug)
|
||||
|
||||
if not db_scratch.owner:
|
||||
if not scratch.owner and request.query_params.get("no_take_ownership") is None:
|
||||
# Give ownership to this profile
|
||||
profile = request.profile
|
||||
|
||||
logging.debug(f"Granting ownership of scratch {db_scratch} to {profile}")
|
||||
logging.debug(f"Granting ownership of scratch {scratch} to {profile}")
|
||||
|
||||
db_scratch.owner = profile
|
||||
db_scratch.save()
|
||||
scratch.owner = profile
|
||||
scratch.save()
|
||||
|
||||
return Response({
|
||||
"scratch": ScratchWithMetadataSerializer(db_scratch, context={ "request": request }).data,
|
||||
})
|
||||
|
||||
elif request.method == "POST":
|
||||
if slug:
|
||||
return Response({"error": "Not allowed to POST with slug"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
ser = ScratchCreateSerializer(data=request.data)
|
||||
ser.is_valid(raise_exception=True)
|
||||
data = ser.validated_data
|
||||
|
||||
arch = data.get("arch")
|
||||
compiler = data.get("compiler", "")
|
||||
if compiler:
|
||||
arch = CompilerWrapper.arch_from_compiler(compiler)
|
||||
if not arch:
|
||||
raise serializers.ValidationError("Unknown compiler")
|
||||
elif not arch:
|
||||
raise serializers.ValidationError("arch not provided")
|
||||
|
||||
target_asm = data["target_asm"]
|
||||
context = data["context"]
|
||||
|
||||
asm = get_db_asm(target_asm)
|
||||
|
||||
assembly, err = CompilerWrapper.assemble_asm(arch, asm)
|
||||
if not assembly:
|
||||
assert isinstance(err, str)
|
||||
|
||||
errors = []
|
||||
|
||||
for line in err.splitlines():
|
||||
if "asm.s:" in line:
|
||||
errors.append(line[line.find("asm.s:") + len("asm.s:") :].strip())
|
||||
else:
|
||||
errors.append(line)
|
||||
|
||||
return Response({
|
||||
"error": "as_error",
|
||||
"error_description": "Error when assembling target asm",
|
||||
"as_errors": errors,
|
||||
}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
source_code = data.get("source_code")
|
||||
if not source_code:
|
||||
source_code = "void func() {}\n"
|
||||
if arch == "mips":
|
||||
source_code = M2CWrapper.decompile(asm.data, context) or source_code
|
||||
|
||||
cc_opts = data.get("compiler_flags", "")
|
||||
if compiler and cc_opts:
|
||||
cc_opts = CompilerWrapper.filter_cc_opts(compiler, cc_opts)
|
||||
|
||||
scratch_data = {
|
||||
"slug": gen_scratch_id(),
|
||||
"arch": arch,
|
||||
"compiler": compiler,
|
||||
"cc_opts": cc_opts,
|
||||
"context": context,
|
||||
"source_code": source_code,
|
||||
"target_assembly": assembly.pk,
|
||||
response = self.head(request, slug)
|
||||
response.data = {
|
||||
"scratch": ScratchWithMetadataSerializer(scratch, context={ "request": request }).data,
|
||||
}
|
||||
return response
|
||||
|
||||
serializer = ScratchSerializer(data=scratch_data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
|
||||
db_scratch = Scratch.objects.get(slug=scratch_data["slug"])
|
||||
|
||||
return Response({
|
||||
"scratch": ScratchWithMetadataSerializer(db_scratch, context={ "request": request }).data,
|
||||
}, status=status.HTTP_201_CREATED)
|
||||
|
||||
elif request.method == "PATCH":
|
||||
if not slug:
|
||||
return Response({"error": "Missing slug"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
def patch(self, request: Request, slug: str):
|
||||
required_params = ["compiler", "cc_opts", "source_code", "context"]
|
||||
|
||||
for param in required_params:
|
||||
if param not in request.data:
|
||||
return Response({"error": f"Missing parameter: {param}"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
db_scratch = get_object_or_404(Scratch, slug=slug)
|
||||
scratch = get_object_or_404(Scratch, slug=slug)
|
||||
|
||||
if db_scratch.owner and db_scratch.owner != request.profile:
|
||||
return Response(status=status.HTTP_403_FORBIDDEN)
|
||||
if scratch.owner and scratch.owner != request.profile:
|
||||
response = self.get(request, slug)
|
||||
response.status_code = status.HTTP_403_FORBIDDEN
|
||||
return response
|
||||
|
||||
# TODO validate
|
||||
db_scratch.compiler = request.data["compiler"]
|
||||
db_scratch.cc_opts = request.data["cc_opts"]
|
||||
db_scratch.source_code = request.data["source_code"]
|
||||
db_scratch.context = request.data["context"]
|
||||
db_scratch.save()
|
||||
return Response(status=status.HTTP_202_ACCEPTED)
|
||||
scratch.compiler = request.data["compiler"]
|
||||
scratch.cc_opts = request.data["cc_opts"]
|
||||
scratch.source_code = request.data["source_code"]
|
||||
scratch.context = request.data["context"]
|
||||
scratch.save()
|
||||
|
||||
return self.get(request, slug)
|
||||
|
||||
@api_view(["POST"])
|
||||
def create_scratch(request):
|
||||
"""
|
||||
Create a scratch
|
||||
"""
|
||||
|
||||
ser = ScratchCreateSerializer(data=request.data)
|
||||
ser.is_valid(raise_exception=True)
|
||||
data = ser.validated_data
|
||||
|
||||
arch = data.get("arch")
|
||||
compiler = data.get("compiler", "")
|
||||
if compiler:
|
||||
arch = CompilerWrapper.arch_from_compiler(compiler)
|
||||
if not arch:
|
||||
raise serializers.ValidationError("Unknown compiler")
|
||||
elif not arch:
|
||||
raise serializers.ValidationError("arch not provided")
|
||||
|
||||
target_asm = data["target_asm"]
|
||||
context = data["context"]
|
||||
|
||||
asm = get_db_asm(target_asm)
|
||||
|
||||
assembly, err = CompilerWrapper.assemble_asm(arch, asm)
|
||||
if not assembly:
|
||||
assert isinstance(err, str)
|
||||
|
||||
errors = []
|
||||
|
||||
for line in err.splitlines():
|
||||
if "asm.s:" in line:
|
||||
errors.append(line[line.find("asm.s:") + len("asm.s:") :].strip())
|
||||
else:
|
||||
errors.append(line)
|
||||
|
||||
return Response({
|
||||
"error": "as_error",
|
||||
"error_description": "Error when assembling target asm",
|
||||
"as_errors": errors,
|
||||
}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
source_code = data.get("source_code")
|
||||
if not source_code:
|
||||
source_code = "void func() {}\n"
|
||||
if arch == "mips":
|
||||
source_code = M2CWrapper.decompile(asm.data, context) or source_code
|
||||
|
||||
cc_opts = data.get("compiler_flags", "")
|
||||
if compiler and cc_opts:
|
||||
cc_opts = CompilerWrapper.filter_cc_opts(compiler, cc_opts)
|
||||
|
||||
scratch_data = {
|
||||
"slug": gen_scratch_id(),
|
||||
"arch": arch,
|
||||
"compiler": compiler,
|
||||
"cc_opts": cc_opts,
|
||||
"context": context,
|
||||
"source_code": source_code,
|
||||
"target_assembly": assembly.pk,
|
||||
}
|
||||
|
||||
serializer = ScratchSerializer(data=scratch_data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
|
||||
db_scratch = Scratch.objects.get(slug=scratch_data["slug"])
|
||||
|
||||
return Response({
|
||||
"scratch": ScratchWithMetadataSerializer(db_scratch, context={ "request": request }).data,
|
||||
}, status=status.HTTP_201_CREATED)
|
||||
|
||||
@api_view(["POST"])
|
||||
def compile(request, slug):
|
||||
@@ -181,12 +191,12 @@ def compile(request, slug):
|
||||
if compilation:
|
||||
diff_output = AsmDifferWrapper.diff(scratch.target_assembly, compilation)
|
||||
|
||||
response_obj = {
|
||||
"diff_output": diff_output,
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
return Response(response_obj)
|
||||
return Response({
|
||||
"compilation": {
|
||||
"diff_output": diff_output,
|
||||
"errors": errors,
|
||||
},
|
||||
})
|
||||
|
||||
@api_view(["POST"])
|
||||
def fork(request, slug):
|
||||
|
||||
Reference in New Issue
Block a user